
Commit 85da4d6

Eetusjo authored and Google-ML-Automation committed
PR #35575: [ROCm] Add autotuner rocblas/hipblaslt backends
Imported from GitHub PR #35575

Adds ROCm-specific autotuner backends. Essentially copies the existing cublas/lt backends and renames them, with minor changes.

Comment from another PR as context, with the ask to separate CUDA/ROCm autotuner backends: #35280 (comment)

Copybara import of the project:

--
e5a5496 by Eetu Sjöblom <eetu.sjoblom@amd.com>:

copy cublas/cublaslt backends to create rocblas/hipblaslt ones

--
a09edaf by Eetu Sjöblom <eetu.sjoblom@amd.com>:

Pass cc to GetBlasComputationType

Merging this change closes #35575

FUTURE_COPYBARA_INTEGRATE_REVIEW=#35575 from ROCm:ci_rocm_autotuner_backends f9e8b77
PiperOrigin-RevId: 853116778
1 parent 4c478a9 commit 85da4d6
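
For orientation before the diffs: this change swaps the ROCm codegen-backend factory from the CUDA CublasBackend to the new RocblasBackend (see factory_rocm.cc below). A minimal sketch of the resulting call, using only names visible in the diffs; the construction of the arguments is assumed, not shown in this commit:

// Sketch based on the factory_rocm.cc diff below. The setup of
// stream_executor, debug_options, compiler, target_config, and
// mlir_context is assumed here, not part of this commit.
std::vector<std::unique_ptr<CodegenBackend>> backends =
    GetCodegenBackendsForROCm(stream_executor, debug_options, compiler,
                              target_config, mlir_context);
// The list now holds a TritonBackend followed by a RocblasBackend
// (previously a CublasBackend was pushed in its place).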

File tree

10 files changed: +1113 -14 lines changed

xla/backends/gpu/autotuner/BUILD

Lines changed: 150 additions & 1 deletion
@@ -630,6 +630,155 @@ cc_library(
     alwayslink = True,
 )
 
+cc_library(
+    name = "rocblas",
+    srcs = ["rocblas.cc"],
+    hdrs = ["rocblas.h"],
+    tags = [
+        "gpu",
+        "rocm-only",
+    ],
+    deps = [
+        ":gpu_codegen_backend",
+        "//xla:autotuning_proto_cc",
+        "//xla:shape_util",
+        "//xla:xla_proto_cc",
+        "//xla/backends/autotuner:codegen_backend",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/utils:hlo_query",
+        "//xla/service:compiler",
+        "//xla/service:hlo_cost_analysis",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/service/gpu:matmul_utils",
+        "//xla/service/gpu/autotuning:redzone_buffers",
+        "//xla/service/gpu/transforms:dot_algorithm_rewriter",
+        "//xla/service/gpu/transforms:gemm_rewriter",
+        "//xla/stream_executor:blas",
+        "//xla/stream_executor:device_address",
+        "//xla/stream_executor:device_address_allocator",
+        "//xla/stream_executor:device_description",
+        "//xla/stream_executor:semantic_version",
+        "//xla/stream_executor:stream_executor_h",
+        "//xla/stream_executor:stream_executor_memory_allocator",
+        "//xla/stream_executor/gpu:gpu_blas_lt",
+        "//xla/stream_executor/gpu:redzone_allocator",
+        "//xla/stream_executor/rocm:rocblas_plugin",
+        "//xla/tools:hlo_decomposer_lib",
+        "//xla/tsl/lib/gtl:iterator_range",
+        "//xla/tsl/platform:errors",
+        "//xla/tsl/platform:statusor",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+    ],
+)
+
+cc_library(
+    name = "hipblaslt",
+    srcs = ["hipblaslt.cc"],
+    hdrs = ["hipblaslt.h"],
+    tags = [
+        "gpu",
+        "rocm-only",
+    ],
+    deps = [
+        ":gpu_codegen_backend",
+        "//xla:autotuning_proto_cc",
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla:xla_proto_cc",
+        "//xla/backends/autotuner:codegen_backend",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:compiler",
+        "//xla/service/gpu:backend_configs_cc",
+        "//xla/service/gpu:cublas_cudnn",
+        "//xla/service/gpu:matmul_utils",
+        "//xla/stream_executor:blas",
+        "//xla/stream_executor:device_description",
+        "//xla/stream_executor:stream",
+        "//xla/stream_executor:stream_executor_h",
+        "//xla/stream_executor/gpu:gpu_blas_lt",
+        "//xla/stream_executor/rocm:amdhipblaslt_plugin",
+        "//xla/tsl/platform:errors",
+        "//xla/tsl/platform:statusor",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+    ],
+)
+
+xla_test(
+    name = "hipblaslt_test",
+    srcs = ["hipblaslt_test.cc"],
+    backends = [
+        "amdgpu_any",
+    ],
+    tags = [
+        "gpu",
+        "rocm-only",
+    ],
+    deps = [
+        ":hipblaslt",
+        "//xla:autotuning_proto_cc",
+        "//xla:xla_proto_cc",
+        "//xla/backends/autotuner:codegen_backend",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/testlib:filecheck",
+        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
+        "//xla/service:compiler",
+        "//xla/service:executable",
+        "//xla/service:platform_util",
+        "//xla/service/gpu:amdgpu_compiler_impl",
+        "//xla/stream_executor:blas",
+        "//xla/stream_executor:device_description_proto_cc",
+        "//xla/stream_executor:stream_executor_h",
+        "//xla/tsl/lib/core:status_test_util",
+        "//xla/tsl/platform:statusor",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:status_matchers",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+xla_test(
+    name = "rocblas_test",
+    srcs = ["rocblas_test.cc"],
+    backends = [
+        "amdgpu_any",
+    ],
+    tags = [
+        "gpu",
+        "rocm-only",
+    ],
+    deps = [
+        ":rocblas",
+        "//xla:autotuning_proto_cc",
+        "//xla:xla_proto_cc",
+        "//xla/backends/autotuner:codegen_backend",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/testlib:filecheck",
+        "//xla/hlo/testlib:hlo_hardware_independent_test_base",
+        "//xla/service:compiler",
+        "//xla/service:executable",
+        "//xla/service:platform_util",
+        "//xla/service/gpu:amdgpu_compiler_impl",
+        "//xla/stream_executor:blas",
+        "//xla/stream_executor:device_description_proto_cc",
+        "//xla/stream_executor:stream_executor_h",
+        "//xla/tsl/lib/core:status_test_util",
+        "//xla/tsl/platform:statusor",
+        "//xla/tsl/util/proto:proto_matchers",
+        "@com_google_absl//absl/status:status_matchers",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
 cc_library(
     name = "factory_rocm",
     srcs = ["factory_rocm.cc"],
@@ -638,8 +787,8 @@ cc_library(
         "rocm-only",
     ],
     deps = [
-        ":cublas",
         ":factory",
+        ":rocblas",
         ":triton",
         "//xla:xla_proto_cc",
         "//xla/backends/autotuner:codegen_backend",

xla/backends/gpu/autotuner/factory_rocm.cc

Lines changed: 5 additions & 5 deletions
@@ -13,16 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_XLA_BACKENDS_GPU_AUTOTUNER_CUDA_FACTORY_H_
-#define TENSORFLOW_COMPILER_XLA_BACKENDS_GPU_AUTOTUNER_CUDA_FACTORY_H_
+#ifndef TENSORFLOW_COMPILER_XLA_BACKENDS_GPU_AUTOTUNER_ROCM_FACTORY_H_
+#define TENSORFLOW_COMPILER_XLA_BACKENDS_GPU_AUTOTUNER_ROCM_FACTORY_H_
 
 #include <memory>
 #include <vector>
 
 #include "mlir/IR/MLIRContext.h"
 #include "xla/backends/autotuner/codegen_backend.h"
-#include "xla/backends/gpu/autotuner/cublas.h"
 #include "xla/backends/gpu/autotuner/factory.h"
+#include "xla/backends/gpu/autotuner/rocblas.h"
 #include "xla/backends/gpu/autotuner/triton.h"
 #include "xla/service/compiler.h"
 #include "xla/stream_executor/platform/platform_object_registry.h"
@@ -42,7 +42,7 @@ std::vector<std::unique_ptr<CodegenBackend>> GetCodegenBackendsForROCm(
   std::vector<std::unique_ptr<CodegenBackend>> backends;
   backends.push_back(std::make_unique<TritonBackend>(
       debug_options, compiler, target_config, mlir_context));
-  backends.push_back(std::make_unique<CublasBackend>(
+  backends.push_back(std::make_unique<RocblasBackend>(
       stream_executor, debug_options, compiler, target_config));
   return backends;
 }
@@ -66,4 +66,4 @@ STREAM_EXECUTOR_REGISTER_OBJECT_STATICALLY(GetFissionBackendsROCmRegistration,
 }  // namespace gpu
 }  // namespace xla
 
-#endif  // TENSORFLOW_COMPILER_XLA_BACKENDS_GPU_AUTOTUNER_CUDA_FACTORY_H_
+#endif  // TENSORFLOW_COMPILER_XLA_BACKENDS_GPU_AUTOTUNER_ROCM_FACTORY_H_
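
The backends registered above implement the CodegenBackend interface (GetSupportedConfigs, GetDefaultConfig, ApplyConfig; see the new file below). A hedged sketch of how an autotuner driver could exercise that interface; the loop and the best_config selection are illustrative, only the method names come from this PR:

// Illustrative driver loop; benchmarking and selection are elided.
for (const std::unique_ptr<CodegenBackend>& backend : backends) {
  TF_ASSIGN_OR_RETURN(auto configs, backend->GetSupportedConfigs(*instr));
  if (configs.empty()) continue;  // Instruction not supported by this backend.
  // ... compile and time each config, choose best_config, then:
  TF_RETURN_IF_ERROR(backend->ApplyConfig(*instr, *best_config));
}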
xla/backends/gpu/autotuner/hipblaslt.cc

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
+/* Copyright 2025 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/backends/gpu/autotuner/hipblaslt.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "xla/autotuning.pb.h"
+#include "xla/backends/autotuner/codegen_backend.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/compiler.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/matmul_utils.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/stream_executor/blas.h"
+#include "xla/stream_executor/device_description.h"
+#include "xla/stream_executor/gpu/gpu_blas_lt.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/tsl/platform/errors.h"
+#include "xla/tsl/platform/statusor.h"
+#include "xla/util.h"
+
+namespace xla {
+namespace gpu {
+
+namespace se = ::stream_executor;
+using se::gpu::BlasLt;
+
+using HipblasLtBackendConfig = AutotuneResult::GemmKey;
+
+namespace {
+
+absl::StatusOr<BlasLt::Epilogue> AsBlasLtEpilogue(
+    GemmBackendConfig_Epilogue epilogue) {
+  switch (epilogue) {
+    case GemmBackendConfig::DEFAULT:
+      return BlasLt::Epilogue::kDefault;
+    case GemmBackendConfig::RELU:
+      return BlasLt::Epilogue::kReLU;
+    case GemmBackendConfig::GELU:
+      return BlasLt::Epilogue::kGELU;
+    case GemmBackendConfig::GELU_AUX:
+      return BlasLt::Epilogue::kGELUWithAux;
+    case GemmBackendConfig::SILU:
+      return BlasLt::Epilogue::kSILU;
+    case GemmBackendConfig::BIAS:
+      return BlasLt::Epilogue::kBias;
+    case GemmBackendConfig::BIAS_RELU:
+      return BlasLt::Epilogue::kBiasThenReLU;
+    case GemmBackendConfig::BIAS_GELU:
+      return BlasLt::Epilogue::kBiasThenGELU;
+    case GemmBackendConfig::BIAS_GELU_AUX:
+      return BlasLt::Epilogue::kBiasThenGELUWithAux;
+    case GemmBackendConfig::BIAS_SILU:
+      return BlasLt::Epilogue::kBiasThenSILU;
+    default:
+      return Internal("Unsupported Epilogue.");
+  }
+}
+
+}  // namespace
+
+bool HipblasLtBackend::IsSupported(const HloInstruction& instr) {
+  return IsCublasLtMatmul(instr) || IsCublasLtMatmulF8(instr);
+}
+
+absl::StatusOr<std::vector<std::unique_ptr<BackendConfig>>>
+HipblasLtBackend::GetSupportedConfigs(const HloInstruction& instr) {
+  if (!IsSupported(instr)) {
+    return std::vector<std::unique_ptr<BackendConfig>>();
+  }
+
+  GpuBackendConfig gpu_config =
+      instr.backend_config<GpuBackendConfig>().value();
+  const GemmBackendConfig& backend_config = gpu_config.gemm_backend_config();
+
+  TF_ASSIGN_OR_RETURN(
+      GemmConfig gemm_config,
+      GemmConfig::For(
+          &instr, target_config().device_description.gpu_compute_capability()));
+
+  TF_ASSIGN_OR_RETURN(BlasLt::Epilogue epilogue,
+                      AsBlasLtEpilogue(backend_config.epilogue()));
+
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<se::Stream> stream,
+                      stream_executor()->CreateStream());
+
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<BlasLt::MatmulPlan> plan,
+      se::gpu::BlasLt::GetMatmulPlan(stream.get(), gemm_config, epilogue));
+
+  const Shape& output_shape = instr.shape();
+  if (!output_shape.IsTuple() || output_shape.tuple_shapes().empty()) {
+    return Internal(
+        "Invalid shape for HipblasLt matmul: output is not a non-empty tuple.");
+  }
+  // The last element of the output tuple is the workspace.
+  const int64_t workspace_size =
+      ShapeUtil::ByteSizeOf(output_shape.tuple_shapes().back());
+
+  TF_ASSIGN_OR_RETURN(
+      std::vector<BlasLt::MatmulAlgorithm> algorithms,
+      plan->GetAlgorithms(stream.get(), GemmConfig::kNumAlgorithms,
+                          workspace_size));
+  int num_algorithms = algorithms.size();
+  std::vector<std::unique_ptr<BackendConfig>> configs;
+  configs.reserve(num_algorithms);
+  for (int i = 0; i < num_algorithms; ++i) {
+    HipblasLtBackendConfig gemm_key;
+    gemm_key.set_algorithm(i);
+    gemm_key.set_autotune_workspace_size(workspace_size);
+    auto any = std::make_unique<google::protobuf::Any>();
+    any->PackFrom(gemm_key);
+    configs.push_back(std::move(any));
+  }
+
+  return configs;
+}
+
+absl::StatusOr<std::unique_ptr<BackendConfig>>
+HipblasLtBackend::GetDefaultConfig(const HloInstruction& instr) {
+  if (!IsSupported(instr)) {
+    return absl::InvalidArgumentError(
+        "Not a HipblasLt custom call instruction.");
+  }
+
+  AutotuneResult::GemmKey gemm_key;
+  gemm_key.set_algorithm(0);
+  auto any = std::make_unique<google::protobuf::Any>();
+  any->PackFrom(gemm_key);
+  return any;
+}
+
+absl::Status HipblasLtBackend::ApplyConfig(HloInstruction& instr,
+                                           const BackendConfig& config) {
+  HipblasLtBackendConfig gemm_key;
+  if (!config.UnpackTo(&gemm_key)) {
+    return absl::InvalidArgumentError(
+        "Failed to unpack HipblasLtBackendConfig from Any.");
+  }
+  TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config,
+                      instr.backend_config<GpuBackendConfig>());
+  GemmBackendConfig& backend_config = *gpu_config.mutable_gemm_backend_config();
+  backend_config.set_selected_algorithm(gemm_key.algorithm());
+  backend_config.set_autotune_workspace_size(
+      gemm_key.autotune_workspace_size());
+  TF_RETURN_IF_ERROR(instr.set_backend_config(std::move(gpu_config)));
+  return absl::OkStatus();
+}
+
+}  // namespace gpu
+}  // namespace xla
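
Configs cross the CodegenBackend boundary as google::protobuf::Any messages wrapping AutotuneResult::GemmKey, as GetSupportedConfigs and ApplyConfig above show. A self-contained sketch of that round trip (pure protobuf, no GPU required; the helper names are hypothetical, the field names are taken from the diff):

#include <cstdint>
#include <memory>

#include "google/protobuf/any.pb.h"
#include "xla/autotuning.pb.h"

// Pack, mirroring HipblasLtBackend::GetSupportedConfigs above.
std::unique_ptr<google::protobuf::Any> MakeGemmConfig(
    int64_t algorithm, int64_t workspace_size) {
  xla::AutotuneResult::GemmKey key;
  key.set_algorithm(algorithm);
  key.set_autotune_workspace_size(workspace_size);
  auto any = std::make_unique<google::protobuf::Any>();
  any->PackFrom(key);
  return any;
}

// Unpack, mirroring HipblasLtBackend::ApplyConfig above.
bool ReadGemmConfig(const google::protobuf::Any& config,
                    xla::AutotuneResult::GemmKey* key) {
  // UnpackTo fails when the Any carries a different message type.
  return config.UnpackTo(key);
}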
