diff --git a/build_tools/rocm/BUILD b/build_tools/rocm/BUILD index 6b45a5344be4d..ca3eb53ee943b 100644 --- a/build_tools/rocm/BUILD +++ b/build_tools/rocm/BUILD @@ -34,21 +34,33 @@ filegroup( ) genrule( - name = "san_wrapper_script", - srcs = [":sanitizer_ignore_lists"], - outs = ["san_wrapper.sh"], + name = "exclusive_wrapper_script", + outs = ["exclusive_wrapper.sh"], cmd = """ echo '#!/bin/bash' > $@ - echo 'exec "$$@"' >> $@ + echo 'exec {lock_fd}>/var/lock/gpulock || exit 1' >> $@ + echo 'flock "$$lock_fd"' >> $@ + echo '"$$@"' >> $@ + echo 'return_code=$$?' >> $@ + echo 'flock -u "$$lock_fd"' >> $@ + echo 'exit $$return_code' >> $@ chmod +x $@ """, ) +# this wrapper serializes test commands: the wrapped command runs +# while holding an exclusive flock on /var/lock/gpulock +sh_binary( + name = "exclusive_local_wrapper", + srcs = [":exclusive_wrapper_script"], + visibility = ["//visibility:public"], +) + # this wrapper ensures the test target # take into account any changes in the ignore list files sh_binary( name = "sanitizer_wrapper", - srcs = [":san_wrapper_script"], + srcs = [":exclusive_wrapper_script"], data = [":sanitizer_ignore_lists"], visibility = ["//visibility:public"], ) diff --git a/build_tools/rocm/lsan_ignore_list.txt b/build_tools/rocm/lsan_ignore_list.txt index b569a1c92a26b..65936e787be05 100644 --- a/build_tools/rocm/lsan_ignore_list.txt +++ b/build_tools/rocm/lsan_ignore_list.txt @@ -3,3 +3,4 @@ leak:libstdc++.so leak:libamdhip64.so leak:libhiprtc.so leak:librccl.so +leak:hwloc_bitmap_alloc diff --git a/build_tools/rocm/platform/BUILD b/build_tools/rocm/platform/BUILD deleted file mode 100644 index 96f0b9e3006de..0000000000000 --- a/build_tools/rocm/platform/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright 2025 The OpenXLA Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -package(default_visibility = ["//visibility:public"]) diff --git a/build_tools/rocm/platform/linux_x64/BUILD b/build_tools/rocm/platform/linux_x64/BUILD deleted file mode 100644 index 7d2411ccf905b..0000000000000 --- a/build_tools/rocm/platform/linux_x64/BUILD +++ /dev/null @@ -1,31 +0,0 @@ -load("@rules_cc//cc:defs.bzl", "cc_toolchain", "cc_toolchain_suite") -#load(":cc_toolchain_config.bzl", "cc_toolchain_config") - -package(default_visibility = ["//visibility:public"]) - -platform( - name = "linux_x64", - constraint_values = [ - "@platforms//os:linux", - "@platforms//cpu:x86_64", - "@bazel_tools//tools/cpp:clang", - ], - exec_properties = { - "container-image": "rocm/tensorflow-build@sha256:7cd444ac48657fee2f5087fbda7766266704d3f8fb2299f681952ae4eabed060", - "OSFamily": "Linux", - }, -) - -platform( - name = "linux_x64_gpu", - constraint_values = [ - "@platforms//os:linux", - "@platforms//cpu:x86_64", - "@bazel_tools//tools/cpp:clang", - ], - exec_properties = { - "container-image": "rocm/tensorflow-build@sha256:7cd444ac48657fee2f5087fbda7766266704d3f8fb2299f681952ae4eabed060", - "OSFamily": "Linux", - "Pool": "linux_x64_gpu", - }, -) diff --git a/build_tools/rocm/rocm_xla.bazelrc b/build_tools/rocm/rocm_xla.bazelrc index 5705df5504bfc..09d6a464b2c80 100644 --- a/build_tools/rocm/rocm_xla.bazelrc +++ b/build_tools/rocm/rocm_xla.bazelrc @@ -5,21 +5,22 @@ build:rocm_dev --remote_cache="https://wardite.cluster.engflow.com" build:rocm_rbe 
--bes_backend="grpcs://wardite.cluster.engflow.com" build:rocm_rbe --bes_results_url="https://wardite.cluster.engflow.com/invocation/" -build:rocm_rbe --host_platform="//build_tools/rocm/platform/linux_x64:linux_x64_gpu" -build:rocm_rbe --extra_execution_platforms="//build_tools/rocm/platform/linux_x64:linux_x64_gpu" -build:rocm_rbe --platforms="//build_tools/rocm/platform/linux_x64:linux_x64_gpu" +build:rocm_rbe --host_platform="@local_config_rocm//rocm:linux_x64" +build:rocm_rbe --extra_execution_platforms="@local_config_rocm//rocm:linux_x64" +build:rocm_rbe --platforms="@local_config_rocm//rocm:linux_x64" build:rocm_rbe --bes_timeout=600s build:rocm_rbe --tls_client_certificate="/tf/certificates/ci-cert.crt" build:rocm_rbe --tls_client_key="/tf/certificates/ci-cert.key" build:rocm_rbe --remote_executor="grpcs://wardite.cluster.engflow.com" build:rocm_rbe --remote_cache="grpcs://wardite.cluster.engflow.com" -build:rocm_rbe --spawn_strategy=local +build:rocm_rbe --spawn_strategy=remote,local build:rocm_rbe --jobs=200 build:rocm_rbe --remote_timeout=3600 build:rocm_rbe --remote_download_minimal build:rocm_rbe --remote_upload_local_results +build:rocm_rbe --grpc_keepalive_time=30s -test:rocm_rbe --strategy=TestRunner=local +test:rocm_rbe --strategy=TestRunner=remote,local build:asan --strip=never build:asan --copt -fsanitize=address @@ -62,7 +63,9 @@ build:xla_sgpu -- \ -//xla/pjrt/distributed:topology_util_test \ -//xla/pjrt/distributed:client_server_test \ -//xla/service/gpu/tests:dynamic_shared_memory_test_amdgpu_any \ --//xla/service/gpu/tests:gpu_cub_sort_test_amdgpu_any +-//xla/service/gpu/tests:gpu_cub_sort_test_amdgpu_any \ +-//xla/tests:iota_test_amdgpu_any \ +-//xla/tests:reduce_window_test_amdgpu_any # TODO: return when it is not flaky! 
test:xla_mgpu -- \ //xla/tests:collective_ops_e2e_test \ diff --git a/build_tools/rocm/run_jax_ut.sh b/build_tools/rocm/run_jax_ut.sh index 7e0ce5d2b4714..3de8b730aa348 100755 --- a/build_tools/rocm/run_jax_ut.sh +++ b/build_tools/rocm/run_jax_ut.sh @@ -3,23 +3,28 @@ set -e JAX_DIR=$1 -XLA_DIR=$2 +XLA_DIR="/tf/xla" # TODO: later use argument passed from CI job pushd $JAX_DIR -python build/build.py build \ +python3 build/build.py build \ --wheels=jax-rocm-plugin \ --configure_only \ --local_xla_path=${XLA_DIR} \ --python_version=3.12 # TODO: run the tests when they are green -bazel build \ +bazel --bazelrc=${XLA_DIR}/build_tools/rocm/rocm_xla.bazelrc test \ --config=rocm \ + --config=rocm_rbe \ + --disk_cache=/tf/disk_cache/jaxlib-v0.7.1 \ --build_tag_filters=cpu,gpu,-tpu,-config-cuda-only \ --test_tag_filters=cpu,gpu,-tpu,-config-cuda-only \ --action_env=TF_ROCM_AMDGPU_TARGETS=gfx908,gfx90a,gfx942 \ + --test_timeout=920,2400,7200,9600 \ --//jax:build_jaxlib=true \ + --run_under=@xla//build_tools/rocm:exclusive_local_wrapper \ + --action_env=REMOTE_GPU_TESTING=1 \ "//tests/..." popd diff --git a/build_tools/rocm/run_xla.sh b/build_tools/rocm/run_xla.sh index a58f07f04b9e5..2760de15061f0 100755 --- a/build_tools/rocm/run_xla.sh +++ b/build_tools/rocm/run_xla.sh @@ -27,17 +27,17 @@ N_BUILD_JOBS=$(grep -c ^processor /proc/cpuinfo) rocm-smi -i STATUS=$? if [ $STATUS -ne 0 ]; then TF_GPU_COUNT=1; else - TF_GPU_COUNT=$(rocm-smi -i|grep 'Device ID' |grep 'GPU' |wc -l) + TF_GPU_COUNT=$(rocm-smi -i | grep 'Device ID' | grep 'GPU' | wc -l) fi TF_TESTS_PER_GPU=1 N_TEST_JOBS=$(expr ${TF_GPU_COUNT} \* ${TF_TESTS_PER_GPU}) -amdgpuname=(`rocminfo | grep gfx | head -n 1`) +amdgpuname=($(rocminfo | grep gfx | head -n 1)) AMD_GPU_GFX_ID=${amdgpuname[1]} echo "" echo "Bazel will use ${N_BUILD_JOBS} concurrent build job(s) and ${N_TEST_JOBS} concurrent test job(s) for gpu ${AMD_GPU_GFX_ID}." 
echo "" -export PYTHON_BIN_PATH=`which python3` +export PYTHON_BIN_PATH=$(which python3) export TF_NEED_ROCM=1 export ROCM_PATH="/opt/rocm" @@ -99,12 +99,13 @@ BAZEL_DISK_CACHE_SIZE=100G BAZEL_DISK_CACHE_DIR="/tf/disk_cache/rocm-jaxlib-v0.7.1" mkdir -p ${BAZEL_DISK_CACHE_DIR} if [ ! -d /tf/pkg ]; then - mkdir -p /tf/pkg + mkdir -p /tf/pkg fi SCRIPT_DIR=$(realpath $(dirname $0)) TAG_FILTERS=$($SCRIPT_DIR/rocm_tag_filters.sh),-multigpu,-multi_gpu_h100,requires-gpu-amd,-skip_rocprofiler_sdk,-no_oss,-oss_excluded,-oss_serial +RBE_OPTIONS=() SANITIZER_ARGS=() if [[ $1 == "asan" ]]; then SANITIZER_ARGS+=("--config=asan") @@ -121,6 +122,12 @@ elif [[ $1 == "tsan" ]]; then HostExecuteStartThunkTest* HostExecuteDoneThunkTest* ) + + # tsan tests appear to be flaky in rbe due to the heavy load + # force them to run locally + RBE_OPTIONS+=( + --strategy=TestRunner=local + ) shift fi @@ -139,6 +146,7 @@ bazel --bazelrc=build_tools/rocm/rocm_xla.bazelrc test \ --flaky_test_attempts=3 \ --keep_going \ --local_test_jobs=${N_TEST_JOBS} \ + --repo_env=TF_ROCM_AMDGPU_TARGETS=gfx908,gfx90a,gfx942,gfx1100 \ --test_env=TF_TESTS_PER_GPU=$TF_TESTS_PER_GPU \ --action_env=XLA_FLAGS="--xla_gpu_enable_llvm_module_compilation_parallelism=true --xla_gpu_force_compilation_parallelism=16" \ --run_under=//build_tools/ci:parallel_gpu_execute \ @@ -146,9 +154,10 @@ bazel --bazelrc=build_tools/rocm/rocm_xla.bazelrc test \ --test_env=MIOPEN_FIND_MODE=1 \ --test_filter=-$(IFS=: ; echo "${EXCLUDED_TESTS[*]}") \ "${SANITIZER_ARGS[@]}" \ - "$@" + "$@" \ + "${RBE_OPTIONS[@]}" # clean up bazel disk_cache bazel shutdown \ - --disk_cache=${BAZEL_DISK_CACHE_DIR} \ - --experimental_disk_cache_gc_max_size=${BAZEL_DISK_CACHE_SIZE} + --disk_cache=${BAZEL_DISK_CACHE_DIR} \ + --experimental_disk_cache_gc_max_size=${BAZEL_DISK_CACHE_SIZE} diff --git a/build_tools/rocm/run_xla_multi_gpu.sh b/build_tools/rocm/run_xla_multi_gpu.sh index 81a45711bbb3c..beb1667c567f4 100755 --- a/build_tools/rocm/run_xla_multi_gpu.sh +++ 
b/build_tools/rocm/run_xla_multi_gpu.sh @@ -60,12 +60,17 @@ if [ ! -d /tf/pkg ]; then fi EXCLUDED_TESTS=( - CollectiveOpsTestE2E.MemcpyP2pLargeMessage - RaggedAllToAllTest/RaggedAllToAllTest.RaggedAllToAll_8GPUs_2ReplicasPerGroups/sync_decomposer - RaggedAllToAllTest/RaggedAllToAllTest.RaggedAllToAll_8GPUs_2ReplicasPerGroups/async_decomposer + # //xla/tests:collective_ops_test_amdgpu_any + RaggedAllToAllTest* + AsyncCollectiveOps* + AsyncMemcpyCollectiveOps* + CollectiveOpsTest* + AllReduceTest* + Fp8CollectiveOpsTest* # //xla/backends/gpu/codegen/triton:fusion_emitter_parametrized_legacy_test_amdgpu_any ElementwiseTestSuiteF32/BinaryElementwiseTest.ElementwiseFusionExecutesCorrectly/f32_atan2 # //xla/tests:collective_ops_e2e_test_amdgpu_any + CollectiveOpsTestE2E.MemcpyP2pLargeMessage CollectiveOpsTestE2EPipelinedNonPipelined.CollectivePipelinerBackward CollectiveOpsTestE2EPipelinedNonPipelined.CollectivePipelinerBackwardStartFromOne # //xla/tools/multihost_hlo_runner:functional_hlo_runner_test @@ -76,6 +81,7 @@ EXCLUDED_TESTS=( SCRIPT_DIR=$(realpath $(dirname $0)) TAG_FILTERS="$($SCRIPT_DIR/rocm_tag_filters.sh)" +RBE_OPTIONS=() SANITIZER_ARGS=() if [[ $1 == "asan" ]]; then SANITIZER_ARGS+=("--run_under=//build_tools/rocm:sanitizer_wrapper") @@ -86,18 +92,10 @@ elif [[ $1 == "tsan" ]]; then SANITIZER_ARGS+=("--run_under=//build_tools/rocm:sanitizer_wrapper") SANITIZER_ARGS+=("--config=tsan") TAG_FILTERS="$TAG_FILTERS,-notsan" - # excluded from tsan - EXCLUDED_TESTS+=( - CollectiveOpsTest* - Fp8CollectiveOpsTest.AllGather_8BitFloat - Fp8CollectiveOpsTest.CollectivePermute_8BitFloat - Fp8CollectiveOpsTest.AllToAll_8BitFloat - AsyncCollectiveOps* - AllReduceTest* - RaggedAllToAllTest* - AsyncCollectiveOps* - AsyncMemcpyCollectiveOps* - RaggedAllToAllTest* + # tsan tests appear to be flaky in rbe due to the heavy load + # force them to run locally + RBE_OPTIONS+=( + --strategy=TestRunner=local ) shift fi @@ -116,14 +114,15 @@ bazel 
--bazelrc=build_tools/rocm/rocm_xla.bazelrc test \ --test_output=errors \ --flaky_test_attempts=3 \ --keep_going \ - --test_strategy=exclusive \ + --run_under=//build_tools/rocm:exclusive_local_wrapper \ + --repo_env=TF_ROCM_AMDGPU_TARGETS=gfx908,gfx90a,gfx942,gfx1100 \ --action_env=XLA_FLAGS=--xla_gpu_force_compilation_parallelism=16 \ --action_env=XLA_FLAGS=--xla_gpu_enable_llvm_module_compilation_parallelism=true \ --action_env=NCCL_MAX_NCHANNELS=1 \ --test_filter=-$(IFS=: ; echo "${EXCLUDED_TESTS[*]}") \ "${SANITIZER_ARGS[@]}" \ "$@" \ - --strategy=TestRunner=local # execute multigpu tests locally as there is no gpu exclusive protection on rbe + "${RBE_OPTIONS[@]}" # clean up bazel disk_cache bazel shutdown \ diff --git a/third_party/gpus/rocm/BUILD.tpl b/third_party/gpus/rocm/BUILD.tpl index 5df27bceff0c9..825eab1bba6e5 100644 --- a/third_party/gpus/rocm/BUILD.tpl +++ b/third_party/gpus/rocm/BUILD.tpl @@ -136,11 +136,14 @@ cc_library( deps = [":rocm_config"], ) +# workaround to bring tensile files to the same fs layout as expected in the lib +# rocblas assumes that tensile files are located in ../rocblas/libraries directory +# hipblaslt assumes that tensile files are located in ../hipblaslt/libraries directory cc_library( name = "rocm_rpath", linkopts = select({ ":build_hermetic": [ - "-Wl,-rpath,%{rocm_toolkit_path}/lib", + "-Wl,-rpath,external/local_config_rocm/rocm/%{rocm_root}/lib", ], ":multiple_rocm_paths": [ "-Wl,-rpath=%{rocm_lib_paths}", @@ -163,7 +166,7 @@ cc_library( cc_library( name = "rocm_hip", - srcs = glob(["%{rocm_root}/lib/libamdhip*.so"]), + srcs = glob(["%{rocm_root}/lib/libamdhip*.so*"]), hdrs = glob(["%{rocm_root}/include/hip/**"]), include_prefix = "rocm", includes = [ @@ -181,7 +184,13 @@ cc_library( # Used by jax_rocm_plugin to minimally link to hip runtime. 
cc_library( name = "hip_runtime", - srcs = glob(["%{rocm_root}/lib/libamdhip*.so"]), + srcs = glob([ + "%{rocm_root}/lib/libamdhip*.so*", + "%{rocm_root}/lib/libamd_comgr.so*", + "%{rocm_root}/lib/librccl*.so*", + "%{rocm_root}/lib/libhipsparse*.so*", + "%{rocm_root}/lib/libhipsolver*.so*", + ]), hdrs = glob(["%{rocm_root}/include/hip/**"]), include_prefix = "rocm", includes = [ @@ -190,8 +199,9 @@ cc_library( strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], deps = [ + ":hipfft", + ":rocfft", ":rocm_config", - ":rocprofiler_register", ":system_libs", ], ) @@ -213,12 +223,12 @@ cc_library( includes = [ "%{rocm_root}/include", ], - # workaround to bring tensile files to the same fs layout as expected in the lib - # rocblas assumes that tensile files are located in ../roblas/libraries directory - linkopts = ["-Wl,-rpath,local_config_rocm/rocm/rocm_dis/lib"], strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], - deps = [":rocm_config"], + deps = [ + ":rocm_config", + ":rocm_rpath", + ], ) cc_library( @@ -247,8 +257,8 @@ cc_library( cc_library( name = "hiprand", - srcs = glob(["%{rocm_root}/lib/libhiprand*.so*"]), hdrs = glob(["%{rocm_root}/include/hiprand/**"]), + data = glob(["%{rocm_root}/lib/libhiprand*.so*"]), include_prefix = "rocm", includes = [ "%{rocm_root}/include", @@ -257,7 +267,10 @@ cc_library( linkstatic = 1, strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], - deps = [":rocm_config"], + deps = [ + ":rocm_config", + ":rocm_rpath", + ], ) miopen_libs = glob([ @@ -266,7 +279,6 @@ miopen_libs = glob([ cc_library( name = "miopen", - srcs = glob(["%{rocm_root}/lib/libMIOpen*.so*"]), hdrs = glob(["%{rocm_root}/include/miopen/**"]), data = select({ ":build_hermetic": miopen_libs, @@ -277,12 +289,12 @@ cc_library( includes = [ "%{rocm_root}/include", ], - # workaround to bring miopen db files to the same fs layout as expected in the lib - # rocblas assumes that miopen db files are 
located in ../share/miopen/db directory - linkopts = ["-Wl,-rpath,local_config_rocm/rocm/rocm_dis/lib"], strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], - deps = [":rocm_config"], + deps = [ + ":rocm_config", + ":rocm_rpath", + ], ) cc_library( @@ -336,14 +348,16 @@ cc_library( name = "hipsparse", srcs = glob(["%{rocm_root}/lib/libhipsparse*.so*"]), hdrs = glob(["%{rocm_root}/include/hipsparse/**"]), - data = glob(["%{rocm_root}/lib/libhipsparse*.so*"]), include_prefix = "rocm", includes = [ "%{rocm_root}/include/", ], strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], - deps = [":rocm_config"], + deps = [ + ":rocm_config", + ":rocsparse", + ], ) roctracer_libs = glob(["%{rocm_root}/lib/libroctracer*.so*"]) @@ -375,7 +389,10 @@ cc_library( ], strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], - deps = [":rocm_config"], + deps = [ + ":hip_runtime", + ":rocm_config", + ], ) cc_library( @@ -409,27 +426,28 @@ hipsolver_libs = glob([ cc_library( name = "hipsolver", - srcs = glob(["%{rocm_root}/lib/libhipsolver*.so*"]), - hdrs = glob(["%{rocm_root}/include/hipsolver/**"]), - data = select({ + srcs = select({ ":build_hermetic": hipsolver_libs, ":multiple_rocm_paths": hipsolver_libs, "//conditions:default": [], }), + hdrs = glob(["%{rocm_root}/include/hipsolver/**"]), include_prefix = "rocm", includes = [ "%{rocm_root}/include/", ], strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], - deps = [":rocm_config"], + deps = [ + ":rocm_config", + ":rocm_rpath", + ], ) hipblas_libs = glob(["%{rocm_root}/lib/libhipblas.so*"]) cc_library( name = "hipblas", - srcs = glob(["%{rocm_root}/lib/libhipblas.so*"]), hdrs = glob(["%{rocm_root}/include/hipblas/**"]), data = select({ ":build_hermetic": hipblas_libs, @@ -445,6 +463,7 @@ cc_library( deps = [ ":hipblas-common", ":rocm_config", + ":rocm_rpath", ], ) @@ -477,12 +496,12 @@ cc_library( includes = [ "%{rocm_root}/include/", ], - # 
workaround to bring tensile files to the same fs layout as expected in the lib - # hibplatslt assumes that tensile files are located in ../hipblaslt/libraries directory - linkopts = ["-Wl,-rpath,local_config_rocm/rocm/rocm_dis/lib"], strip_include_prefix = "%{rocm_root}", visibility = ["//visibility:public"], - deps = [":rocm_config"], + deps = [ + ":rocm_config", + ":rocm_rpath", + ], ) cc_library( @@ -580,3 +599,16 @@ filegroup( srcs = glob(["%{rocm_root}/**"]), visibility = ["//visibility:public"], ) + +platform( + name = "linux_x64", + constraint_values = [ + "@platforms//os:linux", + "@platforms//cpu:x86_64", + "@bazel_tools//tools/cpp:clang", + ], + exec_properties = { + "container-image": "docker://%{rocm_rbe_docker_image}", + "OSFamily": "Linux", + }, +) diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index d60bf2bb16c6e..fbdf891576c15 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -52,6 +52,10 @@ _DISTRIBUTION_PATH = "rocm/rocm_dist" _OS = "OS" _ROCM_VERSION = "ROCM_VERSION" +_TF_ROCM_RBE_DOCKER_IMAGE = "TF_ROCM_RBE_DOCKER_IMAGE" +# rocm/tensorflow-build:latest-jammy-python3.11-rocm7.0.2 +_DEFAULT_TF_ROCM_RBE_DOCKER_IMAGE = "rocm/tensorflow-build@sha256:a2672ff2510b369b4a5f034272a518dc93c2e492894e3befaeef19649632ccaa" + _DEFAULT_ROCM_TOOLKIT_PATH = "/opt/rocm" _TF_ROCM_MULTIPLE_PATHS = "TF_ROCM_MULTIPLE_PATHS" _LLVM_PATH = "LLVM_PATH" @@ -655,6 +659,7 @@ def _create_local_rocm_repository(repository_ctx): repository_dict = { "%{rocm_root}": rocm_toolkit_path, "%{rocm_toolkit_path}": str(repository_ctx.path(rocm_config.rocm_toolkit_path)), + "%{rocm_rbe_docker_image}": repository_ctx.os.environ.get(_TF_ROCM_RBE_DOCKER_IMAGE, _DEFAULT_TF_ROCM_RBE_DOCKER_IMAGE), } is_rocm_clang = _use_rocm_clang(repository_ctx) @@ -873,6 +878,8 @@ _ENVIRONS = [ _TF_ROCM_AMDGPU_TARGETS, _OS, _ROCM_VERSION, + _TF_ROCM_RBE_DOCKER_IMAGE, + _TF_ROCM_MULTIPLE_PATHS, ] remote_rocm_configure = 
repository_rule( diff --git a/xla/backends/gpu/codegen/triton/support.cc b/xla/backends/gpu/codegen/triton/support.cc index 24809a1589a12..0791681fe12e5 100644 --- a/xla/backends/gpu/codegen/triton/support.cc +++ b/xla/backends/gpu/codegen/triton/support.cc @@ -97,7 +97,6 @@ absl::flat_hash_set TritonSupportedUnaryElementwiseOps( if (element_type != PrimitiveType::F8E5M2 && element_type != PrimitiveType::F8E4M3FN && - element_type != PrimitiveType::F8E4M3B11FNUZ && element_type != PrimitiveType::F8E5M2FNUZ && element_type != PrimitiveType::F8E4M3FNUZ) { ret.insert(HloOpcode::kNegate); @@ -147,9 +146,11 @@ CodegenDecision IsTritonSupportedConversion( return error_message(); } - bool is_f8_conversion = - any_is(PrimitiveType::F8E4M3FN) && any_is(PrimitiveType::F8E5M2); - bool is_f8 = any_is(PrimitiveType::F8E4M3FN) || any_is(PrimitiveType::F8E5M2); + auto supported_fp8_types = {F8E4M3FN, F8E5M2, F8E4M3FNUZ, F8E5M2FNUZ}; + bool is_input_fp8 = absl::c_linear_search(supported_fp8_types, input); + bool is_output_fp8 = absl::c_linear_search(supported_fp8_types, output); + bool is_f8_conversion = is_input_fp8 && is_output_fp8; + bool is_f8 = is_input_fp8 || is_output_fp8; bool is_f16_or_f32 = any_is(PrimitiveType::F16) || any_is(PrimitiveType::BF16) || any_is(PrimitiveType::F32); @@ -179,7 +180,6 @@ absl::flat_hash_set TritonSupportedBinaryElementwiseOps( if (element_type == PrimitiveType::S4 || element_type == PrimitiveType::U16 || element_type == PrimitiveType::F8E5M2 || element_type == PrimitiveType::F8E4M3FN || - element_type == PrimitiveType::F8E4M3B11FNUZ || element_type == PrimitiveType::F8E5M2FNUZ || element_type == PrimitiveType::F8E4M3FNUZ) { return {}; @@ -217,6 +217,7 @@ absl::flat_hash_set TritonSupportedBinaryElementwiseOps( ret.insert(HloOpcode::kAtan2); ret.insert(HloOpcode::kPower); ret.insert(HloOpcode::kRemainder); + ret.insert(HloOpcode::kDivide); } return ret; @@ -231,7 +232,6 @@ absl::flat_hash_set TritonSupportedTernaryElementwiseOps( if 
(element_type == PrimitiveType::F8E5M2 || element_type == PrimitiveType::F8E4M3FN || - element_type == PrimitiveType::F8E4M3B11FNUZ || element_type == PrimitiveType::F8E5M2FNUZ || element_type == PrimitiveType::F8E4M3FNUZ) { return {HloOpcode::kSelect}; @@ -263,8 +263,8 @@ CodegenDecision CanTritonHandleReduce( if (reduce.shape().element_type() == PrimitiveType::F8E4M3FN || reduce.shape().element_type() == PrimitiveType::F8E5M2 || reduce.shape().element_type() == PrimitiveType::F8E5M2FNUZ || - reduce.shape().element_type() == PrimitiveType::F8E4M3FNUZ || - reduce.shape().element_type() == PrimitiveType::F8E4M3B11FNUZ) { + reduce.shape().element_type() == PrimitiveType::F8E4M3FNUZ /*|| + reduce.shape().element_type() == PrimitiveType::F8E4M3B11FNUZ*/) { return CodegenDecision::Forbid( "F8E4M3FN and F8E5M2 are not supported for reductions."); } @@ -358,15 +358,15 @@ CodegenDecision AreTypesSupportedByAlgUnsetDot( } } - if (input_type == F8E4M3B11FNUZ || result_type == F8E4M3B11FNUZ || - input_type == F64) { + if (input_type == F8E4M3B11FNUZ || result_type == F8E4M3B11FNUZ) { if (std::holds_alternative(gpu_version)) { return CodegenDecision::Forbid( "Dot operation for F8E4M3B11FNUZ is not supported on ROCM."); } } - auto supported_float_types = {BF16, F16, F32, F64, F8E5M2}; + auto supported_float_types = {BF16, F16, F32, F8E4M3FN, F8E5M2, F8E4M3FNUZ, + F8E5M2FNUZ}; if (absl::c_linear_search(supported_float_types, input_type)) { return CodegenDecision::Allow(); } @@ -375,13 +375,15 @@ CodegenDecision AreTypesSupportedByAlgUnsetDot( return CodegenDecision::Allow(); } - auto partially_supported_signed_types = {S4, S8, S16, S32, S64}; + auto partially_supported_signed_types = {S8, S16, S32, S64}; if (absl::c_linear_search(partially_supported_signed_types, input_type)) { - if (absl::c_linear_search(partially_supported_signed_types, result_type)) { + if ((absl::c_linear_search(partially_supported_signed_types, result_type) && + !std::holds_alternative(gpu_version))) { 
return CodegenDecision::Forbid( "Dot operation does not support these signed integer types."); } - if (primitive_util::IsFloatingPointType(result_type)) { + if (primitive_util::IsFloatingPointType(result_type) && + !std::holds_alternative(gpu_version)) { return CodegenDecision::Forbid( "Dot operation does not support floating point input and signed " "integer result types."); @@ -435,9 +437,9 @@ CodegenDecision AreDotAlgorithmInputAndOutputConversionsSupported( return forbid("Unsupported BF16 on GPUs before Blackwell"); } - if (allowed_operands_types_or->front() == PrimitiveType::F64 && + if (algorithm == PrecisionConfig::ALG_DOT_F64_F64_F64 && std::holds_alternative(gpu_version)) { - return forbid("Unsupported result conversion"); + return forbid("Unsupported BF16 on Rocm"); } if (allowed_operands_types_or->size() != 1) { @@ -679,6 +681,13 @@ CodegenDecision IsTritonSupportedInstructionImpl( return CodegenDecision::Forbid( "dynamic slice is supported but not enabled yet"); case HloOpcode::kBitcast: + if (ShapeUtil::ElementsIn(instr.operand(0)->shape()) != + ShapeUtil::ElementsIn(instr.shape())) { + return CodegenDecision::Forbid( + "only bitcasts with the same number of elements are supported"); + } + return CodegenDecision(instr.shape().element_type() != S4, + "S4 is not supported."); case HloOpcode::kBroadcast: case HloOpcode::kReshape: case HloOpcode::kSlice: diff --git a/xla/backends/gpu/codegen/triton/support_test.cc b/xla/backends/gpu/codegen/triton/support_test.cc index 10f302c208aae..d60cd5eea648e 100644 --- a/xla/backends/gpu/codegen/triton/support_test.cc +++ b/xla/backends/gpu/codegen/triton/support_test.cc @@ -527,16 +527,17 @@ ENTRY triton_computation { any_is(PrimitiveType::F8E4M3FN) && any_is(PrimitiveType::F8E5M2); } - // Crashes due to unsupported/unspecified rounding mode. 
- crashes_on_failure |= (data_type_in == PrimitiveType::F64 && - (data_type_out == PrimitiveType::F8E4M3FN || - data_type_out == PrimitiveType::F8E5M2)); - - // Crashes due to unsupported conversion. - crashes_on_failure |= (data_type_out == PrimitiveType::F64 && - (data_type_in == PrimitiveType::F8E4M3FN || - data_type_in == PrimitiveType::F8E5M2)); - + if (std::holds_alternative(cc)) { + // Crashes due to unsupported/unspecified rounding mode. + crashes_on_failure |= (data_type_in == PrimitiveType::F64 && + (data_type_out == PrimitiveType::F8E4M3FN || + data_type_out == PrimitiveType::F8E5M2)); + + // Crashes due to unsupported conversion. + crashes_on_failure |= (data_type_out == PrimitiveType::F64 && + (data_type_in == PrimitiveType::F8E4M3FN || + data_type_in == PrimitiveType::F8E5M2)); + } RunSupportTest( std::move(ti), /*output_tile_sizes=*/{1, 32}, cc, crashes_on_failure ? ExpectedFailMode::kCrash : ExpectedFailMode::kFail); @@ -577,15 +578,22 @@ ENTRY triton_computation { data_type, opcode)); ExpectedFailMode fail_mode = ExpectedFailMode::kFail; - if (opcode == HloOpcode::kDivide && - (data_type == PrimitiveType::BF16 || data_type == PrimitiveType::F16 || - data_type == PrimitiveType::F8E5M2 || - data_type == PrimitiveType::F8E4M3FN || - data_type == PrimitiveType::F8E4M3B11FNUZ || - data_type == PrimitiveType::F8E5M2FNUZ || - data_type == PrimitiveType::F8E4M3FNUZ)) { - fail_mode = ExpectedFailMode::kCrash; - }; + if (std::holds_alternative(cc)) { + if (opcode == HloOpcode::kDivide && + (data_type == PrimitiveType::BF16 || data_type == PrimitiveType::F16 || + data_type == PrimitiveType::F8E5M2 || + data_type == PrimitiveType::F8E4M3FN)) { + fail_mode = ExpectedFailMode::kCrash; + } + } else { + if (((opcode == HloOpcode::kMaximum || opcode == HloOpcode::kMinimum) && + (data_type == PrimitiveType::F8E5M2 || + data_type == PrimitiveType::F8E4M3FN || + data_type == PrimitiveType::F8E5M2FNUZ || + data_type == PrimitiveType::F8E4M3FNUZ))) { + fail_mode = 
ExpectedFailMode::kFailOrCrash; + } + } RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}, cc, fail_mode); } @@ -614,14 +622,21 @@ ENTRY triton_computation { data_type, opcode)); ExpectedFailMode fail_mode = ExpectedFailMode::kFail; - if (opcode == HloOpcode::kDivide && - (data_type == PrimitiveType::BF16 || data_type == PrimitiveType::F16 || - data_type == PrimitiveType::F8E5M2 || - data_type == PrimitiveType::F8E4M3FN || - data_type == PrimitiveType::F8E4M3B11FNUZ || - data_type == PrimitiveType::F8E5M2FNUZ || - data_type == PrimitiveType::F8E4M3FNUZ)) { - fail_mode = ExpectedFailMode::kCrash; + if (std::holds_alternative(cc)) { + if (opcode == HloOpcode::kDivide && + (data_type == PrimitiveType::BF16 || data_type == PrimitiveType::F16 || + data_type == PrimitiveType::F8E5M2 || + data_type == PrimitiveType::F8E4M3FN)) { + fail_mode = ExpectedFailMode::kCrash; + } + } else { + if (((opcode == HloOpcode::kMaximum || opcode == HloOpcode::kMinimum) && + (data_type == PrimitiveType::F8E5M2 || + data_type == PrimitiveType::F8E4M3FN || + data_type == PrimitiveType::F8E5M2FNUZ || + data_type == PrimitiveType::F8E4M3FNUZ))) { + fail_mode = ExpectedFailMode::kFailOrCrash; + } } RunSupportTest(std::move(ti), /*output_tile_sizes=*/{}, cc, fail_mode); @@ -675,7 +690,20 @@ ENTRY triton_computation { TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(hlo_text, data_type, opcode)); - RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}, cc); + + bool skip_failure_branch_to_avoid_crash = false; + if (std::holds_alternative(cc)) { + skip_failure_branch_to_avoid_crash = + (opcode == HloOpcode::kClamp || opcode == HloOpcode::kSelect) && + (data_type == PrimitiveType::F8E5M2 || + data_type == PrimitiveType::F8E4M3FN || + data_type == PrimitiveType::F8E5M2FNUZ || + data_type == PrimitiveType::F8E4M3FNUZ); + } + + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}, cc, + skip_failure_branch_to_avoid_crash ? 
ExpectedFailMode::kFailOrCrash + : ExpectedFailMode::kFail); } constexpr std::array kTestedOpsTernaryElementwise = {HloOpcode::kSelect, @@ -718,7 +746,9 @@ ENTRY triton_computation { TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); bool crashes_on_failure = data_type == PrimitiveType::F8E4M3FN || - data_type == PrimitiveType::F8E5M2; + data_type == PrimitiveType::F8E5M2 || + data_type == PrimitiveType::F8E5M2FNUZ || + data_type == PrimitiveType::F8E4M3FNUZ; RunSupportTest( std::move(ti), /*output_tile_sizes=*/{1}, cc, crashes_on_failure ? ExpectedFailMode::kCrash : ExpectedFailMode::kFail); @@ -742,7 +772,7 @@ ENTRY triton_computation { ParseTemplateAndGetInstruction(kHloTestTemplate, F32, HloOpcode::kReduce)); RunSupportTest(std::move(ti), /*output_tile_sizes=*/{3, 4}, - se::CudaComputeCapability::Ampere()); + CudaAmpereOrRocm()); } TEST_P( @@ -789,7 +819,9 @@ ENTRY triton_computation { ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); bool crashes_on_failure = data_type == PrimitiveType::F8E4M3FN || - data_type == PrimitiveType::F8E5M2; + data_type == PrimitiveType::F8E5M2 || + data_type == PrimitiveType::F8E5M2FNUZ || + data_type == PrimitiveType::F8E4M3FNUZ; RunSupportTest( std::move(ti), /*output_tile_sizes=*/{1}, cc, crashes_on_failure ? ExpectedFailMode::kCrash : ExpectedFailMode::kFail); @@ -825,7 +857,7 @@ ENTRY triton_computation { } TEST_F(ReduceTest, ReduceWithNonConstReduceValueIsSupportedWithTriton) { - const se::GpuComputeCapability cc = se::CudaComputeCapability::Ampere(); + const se::GpuComputeCapability cc = CudaAmpereOrRocm(); const std::string kHloTestTemplate = R"( add { Arg_0 = $0[] parameter(0) @@ -905,12 +937,15 @@ ENTRY triton_computation { // TODO(b/361526623): Reduce the cases where emitter crashes. 
ExpectedFailMode fail_mode = ExpectedFailMode::kFail; - if (opcode == HloOpcode::kDivide && (data_type == BF16 || data_type == F16)) { + if (opcode == HloOpcode::kDivide && (data_type == BF16 || + data_type == F16)) { fail_mode = ExpectedFailMode::kCrash; } - if (data_type == F8E4M3FN || data_type == F8E5M2) { + if (data_type == F8E4M3FN || data_type == F8E5M2 || data_type == PrimitiveType::F8E5M2FNUZ || + data_type == PrimitiveType::F8E4M3FNUZ) { fail_mode = ExpectedFailMode::kFailOrCrash; } + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc, fail_mode); } @@ -1854,12 +1889,12 @@ TEST_P(DotTypesTest, Dot) { fail_mode = ExpectedFailMode::kFailOrCrash; } } - if (absl::c_linear_search(std::vector{F8E5M2FNUZ, F8E4M3FNUZ, F8E4M3FN}, input_type) || - absl::c_linear_search(std::vector{F8E5M2FNUZ, F8E4M3FNUZ, F8E4M3FN}, result_type) || - input_type == F64) { - if (std::holds_alternative(cc)) { - // Hits llvm::report_fatal_error during Triton compilation. - fail_mode = ExpectedFailMode::kFailOrCrash; + if (std::holds_alternative(cc)) { + if (absl::c_linear_search(std::vector{F8E5M2FNUZ, F8E4M3FNUZ, F8E4M3FN}, input_type) || + absl::c_linear_search(std::vector{F8E5M2FNUZ, F8E4M3FNUZ, F8E4M3FN}, result_type) || + input_type == F64) { + // Hits llvm::report_fatal_error during Triton compilation. + fail_mode = ExpectedFailMode::kFailOrCrash; } } @@ -2169,49 +2204,6 @@ ENTRY triton_computation { CudaAmpereOrRocm()); } -TEST_F(DotTest, SparsityConfiguration) { - // Note that support rejects this HLO as u16 is not supported. 
- const std::string kHloTestTemplate = R"( -flhs { - ROOT result = $0[128,128] parameter(0) -} - -frhs { - ROOT result = $0[256,512] parameter(0) -} - -ENTRY triton_computation { - p0 = $0[128,128] parameter(0) - p1 = $0[256,512] parameter(1) - lhs = $0[128,128] fusion(p0), kind=kCustom, calls=flhs, backend_config={ - "fusion_backend_config":{ - "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ - "output_tiles":[{"sizes":["16", "64"]}] - } - } - } - rhs = $0[256,512] fusion(p1), kind=kCustom, calls=frhs, backend_config={ - "fusion_backend_config":{ - "kind":"__triton_nested_gemm_fusion", "block_level_fusion_config":{ - "output_tiles":[{"sizes":["64", "32"]}] - } - } - } - meta = u16[128,16] parameter(2) - ROOT result = $0[128,512] dot(lhs, rhs, meta), - lhs_contracting_dims={1}, - rhs_contracting_dims={0}, - sparsity=L.1@2:4 -} -)"; - TF_ASSERT_OK_AND_ASSIGN( - TestedInstruction ti, - ParseTemplateAndGetInstruction(kHloTestTemplate, F32, HloOpcode::kDot, - /* use_nested_gemm_fusions=*/true)); - RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16, 32}, - CudaAmpereOrRocm()); -} - class DotPrecisionTest : public DotTest, public ::testing::WithParamInterface< @@ -2272,14 +2264,9 @@ ENTRY triton_computation { if (absl::c_linear_search(std::vector{F8E5M2, F8E4M3FN, S8}, data_type)) { fail_mode = ExpectedFailMode::kFailOrCrash; } - if (std::holds_alternative(cc)) { - if (data_type == F64) { - fail_mode = ExpectedFailMode::kFailOrCrash; - } - } if (std::holds_alternative(cc)) { if (absl::c_linear_search(std::vector{F8E4M3FNUZ, F8E5M2FNUZ, F8E4M3FN, - F64}, data_type)) { + S8, S16, S32, S64}, data_type)) { fail_mode = ExpectedFailMode::kFailOrCrash; } } @@ -2380,7 +2367,7 @@ ENTRY triton_computation { if (std::holds_alternative(cc)) { if (absl::c_linear_search(std::vector{F8E4M3FN, F8E5M2FNUZ, F8E4M3FNUZ, F64}, data_type) || - (absl::c_linear_search(std::vector{F16, S64, S32, S16, BF16, F32}, + (absl::c_linear_search(std::vector{S64, S32, S16, BF16, 
F16, F32}, data_type) && algorithm == xla::PrecisionConfig::ALG_DOT_F64_F64_F64)) { fail_mode = ExpectedFailMode::kFailOrCrash; diff --git a/xla/backends/gpu/runtime/BUILD b/xla/backends/gpu/runtime/BUILD index 9f25e057c5e2b..f0ce9fad75a68 100644 --- a/xla/backends/gpu/runtime/BUILD +++ b/xla/backends/gpu/runtime/BUILD @@ -1812,6 +1812,8 @@ xla_test( backend_tags = { "gpu": [ "multi_gpu_h100", + "multi_gpu", + "local", "no_oss", ], }, diff --git a/xla/service/gpu/gpu_device_info_for_tests.cc b/xla/service/gpu/gpu_device_info_for_tests.cc index 6d7607d991f09..2edda53dda49f 100644 --- a/xla/service/gpu/gpu_device_info_for_tests.cc +++ b/xla/service/gpu/gpu_device_info_for_tests.cc @@ -80,7 +80,7 @@ stream_executor::DeviceDescription TestGpuDeviceInfo::AMDMI210DeviceInfo() { b.set_threads_per_block_limit(1024); b.set_threads_per_warp(64); b.set_shared_memory_per_block(64 * 1024); - b.set_shared_memory_per_block_optin(0); + b.set_shared_memory_per_block_optin(64 * 1024); b.set_shared_memory_per_core(64 * 1024); b.set_threads_per_core_limit(2048); b.set_core_count(104); diff --git a/xla/tests/BUILD b/xla/tests/BUILD index c966f08b2614c..b9e3b5e7f4424 100644 --- a/xla/tests/BUILD +++ b/xla/tests/BUILD @@ -1009,6 +1009,7 @@ xla_test( srcs = ["dot_operation_test.cc"], shard_count = 20, tags = [ + "local", # TODO: remove when remote execution is fixed "optonly", ], deps = [ @@ -1242,6 +1243,7 @@ xla_test( }, shard_count = 20, tags = [ + "local", "optonly", ], deps = [ @@ -1969,6 +1971,7 @@ cc_library( # This is set intentionally as to avoid the default behavior of the TSL # `cc_library` definition that is used in this file. 
compatible_with = [], + tags = ["local"], # TODO: remove when remote execution is fixed deps = [ ":client_library_test_base", ":hlo_test_base", @@ -2622,6 +2625,7 @@ xla_test( backend_tags = { "gpu": [ "multi_gpu", + "local", "no_oss", ], "cpu": [ @@ -2721,6 +2725,7 @@ xla_test( backend_tags = { "gpu": [ "multi_gpu", + "local", "no_oss", ], }, @@ -2801,7 +2806,10 @@ xla_test( ], }, backends = ["gpu"], - tags = ["test_migrated_to_hlo_runner_pjrt"], + tags = [ + "test_migrated_to_hlo_runner_pjrt", + "local", + ], deps = [ ":hlo_pjrt_test_base", ":literal_test_util", @@ -3356,6 +3364,7 @@ xla_test( }, shard_count = 50, tags = [ + "local", # TODO: remove when remote execution is fixed "test_migrated_to_hlo_runner_pjrt", ], deps = [ diff --git a/xla/tools/multihost_hlo_runner/BUILD b/xla/tools/multihost_hlo_runner/BUILD index fdd366b0860c3..f5f7f7cb63f38 100644 --- a/xla/tools/multihost_hlo_runner/BUILD +++ b/xla/tools/multihost_hlo_runner/BUILD @@ -232,6 +232,8 @@ xla_test( backend_tags = { "gpu": [ "multi_gpu_h100", + "multi_gpu", + "local", "no_oss", "nomsan", ], diff --git a/xla/tsl/platform/default/build_config_root.bzl b/xla/tsl/platform/default/build_config_root.bzl index 5b71a829b2a02..e69e79d541a98 100644 --- a/xla/tsl/platform/default/build_config_root.bzl +++ b/xla/tsl/platform/default/build_config_root.bzl @@ -19,6 +19,14 @@ GPU_TEST_PROPERTIES = { "Pool": "gpu-pool", } +ROCM_SINGLE_GPU_TEST_PROPERTIES = { + "test.Pool": "linux_x64_gpu", +} + +ROCM_MULTI_GPU_TEST_PROPERTIES = { + "test.Pool": "linux_x64_multigpu", +} + def tf_gpu_tests_tags(): """Gets tags for TensorFlow GPU tests based on the configured environment. 
@@ -38,9 +46,23 @@ def tf_gpu_tests_tags(): def tf_cuda_tests_tags(): return tf_gpu_tests_tags() +def tf_has_tag(kwargs, tag): + return ("tags" in kwargs and kwargs["tags"] != None and tag in kwargs["tags"]) + def tf_exec_properties(kwargs): - if ("tags" in kwargs and kwargs["tags"] != None and - "remote-gpu" in kwargs["tags"]): + """Gets execution_properties for TensorFlow GPU tests based on the provided tags. + + Args: + kwargs: all arguments of the xla test target + Returns: + execution_properties with the execution pool names for rbe. + """ + if is_rocm_configured(): + if tf_has_tag(kwargs, "multi_gpu"): + return ROCM_MULTI_GPU_TEST_PROPERTIES + if tf_has_tag(kwargs, "gpu"): + return ROCM_SINGLE_GPU_TEST_PROPERTIES + elif tf_has_tag(kwargs, "remote-gpu"): return GPU_TEST_PROPERTIES return {}