
Commit b170c49

[Autotuner] Add Fp8 cuBLASLt fallback for cublas backend.
- This is to match the current behavior in XLA; the gemm-rewriter already has many checks that rewrite to a cuBLASLt matmul.
- We are in any case trying to deprecate legacy cuBLAS and enable cuBLASLt.

PiperOrigin-RevId: 839273756
1 parent: 4ed551d

5 files changed: +125 additions, -18 deletions

- xla/backends/gpu/autotuner/BUILD
- xla/backends/gpu/autotuner/cublas.cc
- xla/backends/gpu/autotuner/cublas.h
- xla/backends/gpu/autotuner/cublas_test.cc
- xla/backends/gpu/autotuner/fission_backend_test.cc

xla/backends/gpu/autotuner/BUILD

Lines changed: 0 additions & 1 deletion
@@ -832,7 +832,6 @@ xla_test(
         ":fission_backend",
         ":gpu_codegen_backend",
         "//xla/backends/autotuner:codegen_backend",
-        "//xla/hlo/analysis:symbolic_expr",
         "//xla/hlo/ir:hlo",
         "//xla/hlo/pass:hlo_pass_pipeline",
         "//xla/hlo/testlib:hlo_hardware_independent_test_base",

xla/backends/gpu/autotuner/cublas.cc

Lines changed: 19 additions & 9 deletions
@@ -25,14 +25,9 @@ limitations under the License.
 #include "xla/autotuning.pb.h"
 #include "xla/backends/autotuner/codegen_backend.h"
 #include "xla/hlo/ir/hlo_instruction.h"
-#include "xla/hlo/ir/hlo_opcode.h"
-#include "xla/hlo/utils/hlo_query.h"
 #include "xla/service/gpu/backend_configs.pb.h"
 #include "xla/service/gpu/cublas_cudnn.h"
 #include "xla/service/gpu/matmul_utils.h"
-#include "xla/service/gpu/transforms/dot_algorithm_rewriter.h"
-#include "xla/service/gpu/transforms/gemm_rewriter.h"
-#include "xla/service/hlo_cost_analysis.h"
 #include "xla/stream_executor/blas.h"
 #include "xla/stream_executor/device_description.h"
 #include "xla/stream_executor/device_memory.h"
@@ -49,10 +44,19 @@ namespace se = ::stream_executor;
 
 absl::StatusOr<std::vector<std::unique_ptr<BackendConfig>>>
 CublasBackend::GetSupportedConfigs(const HloInstruction& instr) {
-  if (!IsLegacyCublasMatmul(instr)) {
+  if (!IsSupported(instr)) {
     return std::vector<std::unique_ptr<BackendConfig>>();
   }
 
+  if (ShouldUseCublasLt(instr)) {
+    std::vector<std::unique_ptr<BackendConfig>> configs;
+    AutotuneResult::GemmKey gemm_key;
+    gemm_key.set_algorithm(0);
+    configs.push_back(std::make_unique<google::protobuf::Any>());
+    configs.back()->PackFrom(gemm_key);
+    return configs;
+  }
+
   std::unique_ptr<se::DeviceMemoryAllocator> allocator =
       std::make_unique<se::StreamExecutorMemoryAllocator>(stream_executor());
   TF_ASSIGN_OR_RETURN(
@@ -126,14 +130,16 @@ CublasBackend::GetSupportedConfigs(const HloInstruction& instr) {
 
 absl::StatusOr<std::unique_ptr<BackendConfig>> CublasBackend::GetDefaultConfig(
     const HloInstruction& instr) {
-  if (!IsLegacyCublasMatmul(instr)) {
+  if (!IsSupported(instr)) {
     return absl::InvalidArgumentError(
         "CublasBackend does not support this instruction.");
   }
-
   AutotuneResult::GemmKey gemm_key;
   gemm_key.set_algorithm(se::blas::kDefaultAlgorithm);
   auto any = std::make_unique<google::protobuf::Any>();
+  if (ShouldUseCublasLt(instr)) {
+    gemm_key.set_algorithm(0);
+  }
   any->PackFrom(gemm_key);
   return any;
 }
@@ -154,7 +160,11 @@ absl::Status CublasBackend::ApplyConfig(HloInstruction& instr,
 }
 
 bool CublasBackend::IsSupported(const HloInstruction& instr) {
-  return IsLegacyCublasMatmul(instr);
+  return IsLegacyCublasMatmul(instr) || ShouldUseCublasLt(instr);
+}
+
+bool CublasBackend::ShouldUseCublasLt(const HloInstruction& instr) {
+  return fp8_lt_fallback_ && IsCublasLtMatmulF8(instr);
 }
 
 } // namespace gpu
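
Note on the hunks above: when fp8_lt_fallback_ is set and the instruction is an F8 cuBLASLt matmul custom call, GetSupportedConfigs returns exactly one candidate config, a GemmKey with algorithm 0 packed into a google::protobuf::Any. A minimal standalone sketch of that packing step follows; it assumes only protobuf and the AutotuneResult proto from xla/autotuning.pb.h, and the helper name Fp8LtFallbackConfigs is hypothetical, not part of this change.

#include <memory>
#include <vector>

#include "google/protobuf/any.pb.h"
#include "xla/autotuning.pb.h"

// Sketch only: mirrors the fallback branch added to
// CublasBackend::GetSupportedConfigs above. The real method also handles the
// legacy cuBLAS path and device-memory setup, which are elided here.
std::vector<std::unique_ptr<google::protobuf::Any>> Fp8LtFallbackConfigs() {
  std::vector<std::unique_ptr<google::protobuf::Any>> configs;
  xla::AutotuneResult::GemmKey gemm_key;
  gemm_key.set_algorithm(0);  // The fallback pins algorithm 0.
  configs.push_back(std::make_unique<google::protobuf::Any>());
  configs.back()->PackFrom(gemm_key);
  return configs;  // Exactly one candidate config for the autotuner.
}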

xla/backends/gpu/autotuner/cublas.h

Lines changed: 9 additions & 3 deletions
@@ -32,7 +32,8 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
-// A codegen backend for cuBLAS.
+// A codegen backend for cuBLAS, with configurable fallback to cuBLAS LT for F8
+// matmuls.
 // This backend is used to autotune cuBLAS algorithms.
 //
 // Cublas calls are represented as custom-call instructions, with and
@@ -48,9 +49,11 @@ class CublasBackend : public GpuCodegenBackend {
  public:
   explicit CublasBackend(stream_executor::StreamExecutor* stream_executor,
                          const DebugOptions* debug_options, Compiler* compiler,
-                         const Compiler::GpuTargetConfig* target_config)
+                         const Compiler::GpuTargetConfig* target_config,
+                         bool fp8_lt_fallback = false)
       : GpuCodegenBackend("Cublas", debug_options, compiler, target_config,
-                          stream_executor) {}
+                          stream_executor),
+        fp8_lt_fallback_(fp8_lt_fallback) {}
 
   absl::StatusOr<std::vector<std::unique_ptr<BackendConfig>>>
   GetSupportedConfigs(const HloInstruction& instr) override;
@@ -62,7 +65,10 @@ class CublasBackend : public GpuCodegenBackend {
                            const BackendConfig& config) override;
 
  private:
+  bool ShouldUseCublasLt(const HloInstruction& instr);
+
   bool IsSupported(const HloInstruction& instr) override;
+  bool fp8_lt_fallback_;
 };
 
 } // namespace gpu
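
For context, a hedged usage sketch of the new constructor argument, lifted from the test added in cublas_test.cc below; the fixture members (stream_executor_, debug_options_, compiler_, target_config_) are assumed to exist as in that test:

// fp8_lt_fallback defaults to false, so existing callers are unaffected;
// passing true lets the backend also claim "__cublas$lt$matmul$f8" calls.
CublasBackend backend(stream_executor_, &debug_options_, &compiler_,
                      &target_config_, /*fp8_lt_fallback=*/true);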

xla/backends/gpu/autotuner/cublas_test.cc

Lines changed: 65 additions & 5 deletions
@@ -44,6 +44,10 @@ namespace gpu {
 
 using CublasBackendConfig = AutotuneResult::GemmKey;
 
+using absl_testing::IsOk;
+using absl_testing::IsOkAndHolds;
+using ::testing::IsEmpty;
+using ::testing::Not;
 using ::tsl::proto_testing::EqualsProto;
 
 const char kCublasCustomCallHlo[] = R"(
@@ -68,6 +72,48 @@ const char kCublasCustomCallHlo[] = R"(
   ROOT %get-tuple-element = f32[100,100]{1,0} get-tuple-element(%custom-call.1), index=0
 })";
 
+const char kCublasLtCustomCallHlo[] = R"(
+HloModule test, entry_computation_layout={(f8e4m3fn[16,32]{1,0}, f8e5m2[32,16]{1,0}, f32[], f32[])->f32[16,16]{1,0}}
+
+ENTRY %test (x: f8e4m3fn[16,32], y: f8e5m2[32,16], x_scale: f32[], y_scale: f32[]) -> f32[16,16] {
+  %x = f8e4m3fn[16,32]{1,0} parameter(0)
+  %y = f8e5m2[32,16]{1,0} parameter(1)
+  %transpose = f8e5m2[16,32]{1,0} transpose(%y), dimensions={1,0}
+  %x_scale = f32[] parameter(2)
+  %y_scale = f32[] parameter(3)
+  %cublas-gemm.1 = (f32[16,16]{1,0}, s8[33554432]{0}) custom-call(%x, %transpose, %x_scale, %y_scale),
+    custom_call_target="__cublas$lt$matmul$f8",
+    backend_config={
+      "operation_queue_id":"0",
+      "wait_on_operation_queues":[],
+      "gemm_backend_config":{
+        "alpha_real":1,
+        "beta":0,
+        "dot_dimension_numbers":{
+          "lhs_contracting_dimensions":["1"],
+          "rhs_contracting_dimensions":["1"],
+          "lhs_batch_dimensions":[],
+          "rhs_batch_dimensions":[]
+        },
+        "alpha_imag":0,
+        "precision_config":{
+          "operand_precision":["DEFAULT","DEFAULT"],
+          "algorithm":"ALG_UNSET"
+        },
+        "epilogue":"DEFAULT",
+        "lhs_stride":"512",
+        "rhs_stride":"512",
+        "grad_x":false,
+        "grad_y":false,
+        "damax_output":false
+      },
+      "force_earliest_schedule":false,
+      "reification_cost":[],
+      "device_type":"DEVICE_TYPE_INVALID"
+    }
+  ROOT %get-tuple-element = f32[16,16]{1,0} get-tuple-element(%cublas-gemm.1), index=0
+})";
+
 const char kUnsupportedHlo[] = R"(
 HloModule module
 
@@ -122,8 +168,22 @@ TEST_F(CublasBackendTest, GetSupportedConfigsFromCublasCustomCall) {
   absl::StatusOr<std::vector<std::unique_ptr<BackendConfig>>> configs =
       backend_.GetSupportedConfigs(
          (*hlo_module->entry_computation()->root_instruction()->operand(0)));
-  EXPECT_THAT(configs, absl_testing::IsOk());
-  EXPECT_GT(configs.value().size(), 0);
+  EXPECT_THAT(configs, IsOkAndHolds(Not(IsEmpty())));
+}
+
+TEST_F(CublasBackendTest, CublasLtCustomCall) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
+                          ParseAndReturnVerifiedModule(kCublasLtCustomCallHlo));
+  const HloInstruction* instr =
+      hlo_module->entry_computation()->root_instruction()->operand(0);
+  CublasBackend backend(stream_executor_, &debug_options_, &compiler_,
+                        &target_config_, /*fp8_lt_fallback=*/true);
+  absl::StatusOr<std::vector<std::unique_ptr<BackendConfig>>> configs =
+      backend.GetSupportedConfigs(*instr);
+  EXPECT_THAT(configs, IsOkAndHolds(Not(IsEmpty())));
+
+  EXPECT_THAT(backend.GetDefaultConfig(*instr), IsOk());
+  EXPECT_THAT(backend.Compile(*instr, *configs.value()[0]), IsOk());
 }
 
 TEST_F(CublasBackendTest,
@@ -133,7 +193,7 @@ TEST_F(CublasBackendTest,
   absl::StatusOr<std::vector<std::unique_ptr<BackendConfig>>> configs =
       backend_.GetSupportedConfigs(
          (*hlo_module->entry_computation()->root_instruction()));
-  EXPECT_THAT(configs, absl_testing::IsOkAndHolds(testing::SizeIs(0)));
+  EXPECT_THAT(configs, IsOkAndHolds(testing::SizeIs(0)));
 }
 
 TEST_F(CublasBackendTest, GetDefaultConfigFromCublasCustomCall) {
@@ -162,7 +222,7 @@ TEST_F(CublasBackendTest, ApplyConfig) {
                                any));
   EXPECT_THAT(RunFileCheck(hlo_module->ToString(),
                            "CHECK: \"selected_algorithm\":\"2\""),
-              absl_testing::IsOkAndHolds(true));
+              IsOkAndHolds(true));
 }
 
 TEST_F(CublasBackendTest, Compile) {
@@ -174,7 +234,7 @@ TEST_F(CublasBackendTest, Compile) {
           *(module->entry_computation()->root_instruction()->operand(0))));
   absl::StatusOr<std::unique_ptr<Executable>> executable = backend_.Compile(
       *(module->entry_computation()->root_instruction()), *config);
-  EXPECT_THAT(executable, absl_testing::IsOk());
+  EXPECT_THAT(executable, IsOk());
 }
 
 } // namespace gpu

xla/backends/gpu/autotuner/fission_backend_test.cc

Lines changed: 32 additions & 0 deletions
@@ -74,6 +74,21 @@ const char kTritonFusionHlo[] = R"(
     backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}}
 })";
 
+const char kF8TritonFusionHlo[] = R"(
+HloModule o
+
+gemm_fusion {
+  p0 = f8e4m3fn[64,6144]{1,0} parameter(0)
+  p1 = f8e4m3fn[64,6144]{1,0} parameter(1)
+  ROOT %dot.0 = f32[64,64]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={1}
+}
+
+ENTRY main {
+  p0 = f8e4m3fn[64,6144]{1,0} parameter(0)
+  p1 = f8e4m3fn[64,6144]{1,0} parameter(1)
+  ROOT %dot.0 = f32[64,64]{1,0} fusion(p0, p1), kind=kCustom, calls=gemm_fusion, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false}
+})";
+
 const char kUnsupportedFusionHlo[] = R"(
 HloModule module
 computation {
@@ -144,6 +159,15 @@ class FissionTest : public HloHardwareIndependentTestBase,
                                           compiler, target_config);
   }
 
+  // Static helper to create a CublasBackend.
+  static std::unique_ptr<GpuCodegenBackend> CreateCublasBackendWiithF8Fallback(
+      se::StreamExecutor* stream_executor, const DebugOptions* debug_options,
+      Compiler* compiler, const Compiler::GpuTargetConfig* target_config) {
+    return std::make_unique<CublasBackend>(stream_executor, debug_options,
+                                           compiler, target_config,
+                                           /*enable_f8_fallback=*/true);
+  }
+
   // Static helper to create a CustomKernelBackend.
   static std::unique_ptr<GpuCodegenBackend> CreateCustomKernelBackend(
       se::StreamExecutor* stream_executor, const DebugOptions* debug_options,
@@ -245,6 +269,14 @@ INSTANTIATE_TEST_SUITE_P(
          {"custom_call_target=\"__cublas$gemm\"",
           "\"selected_algorithm\":\"-1\""},
          /*expected_backend_name=*/"Cublas_fission"},
+        {"TritonFusion_CublasLt_F8",
+         kF8TritonFusionHlo,
+         &FissionTest::GetCublasRewriterPipeline,
+         &FissionTest::CreateCublasBackendWiithF8Fallback,
+         /*expected_module_substrings=*/
+         {"custom_call_target=\"__cublas$lt$matmul$f8\"",
+          "\"selected_algorithm\":\"0\""},
+         /*expected_backend_name=*/"Cublas_fission"},
         {"TritonFusion_CustomKernel",
         kTritonFusionHlo,
         &FissionTest::GetCustomKernelRewriterPipeline,
