Skip to content

Commit de6c87c

Browse files
[Autotuner] Extend with missing features.
- Doesn't autotune is there is a config already assigned - Add CustomKernel auotuning support - Enable cubLASLt fallback for cublas backend. PiperOrigin-RevId: 839315624
1 parent 2e19703 commit de6c87c

File tree

2 files changed

+29
-7
lines changed

2 files changed

+29
-7
lines changed

xla/service/gpu/autotuning/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ cc_library(
142142
"//xla:xla_proto_cc",
143143
"//xla/backends/autotuner:codegen_backend",
144144
"//xla/backends/gpu/autotuner:cublas",
145+
"//xla/backends/gpu/autotuner:custom_kernel",
145146
"//xla/backends/gpu/autotuner:fission_backend",
146147
"//xla/backends/gpu/autotuner:triton",
147148
"//xla/backends/gpu/codegen/triton:tma_utils",

xla/service/gpu/autotuning/gemm_fusion_autotuner.cc

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ limitations under the License.
4747
#include "xla/autotuning.pb.h"
4848
#include "xla/backends/autotuner/codegen_backend.h"
4949
#include "xla/backends/gpu/autotuner/cublas.h"
50+
#include "xla/backends/gpu/autotuner/custom_kernel.h"
5051
#include "xla/backends/gpu/autotuner/fission_backend.h"
5152
#include "xla/backends/gpu/autotuner/triton.h"
5253
#include "xla/backends/gpu/runtime/buffer_comparator.h"
@@ -141,20 +142,29 @@ using ProfilingOutput = AutotunerCompileUtil::ProfilingOutput;
141142
namespace {
142143

143144
std::unique_ptr<HloPassPipeline> GetCublasRewriterPipeline(
144-
const se::DeviceDescription& device_description) {
145+
const se::DeviceDescription* device_description) {
145146
auto pipeline = std::make_unique<HloPassPipeline>("cublas_rewriter_pipeline");
146147
pipeline->AddPass(std::make_unique<DotAlgorithmRewriter>());
147148
for (GemmRewriterOptions::DType dtype :
148149
{GemmRewriterOptions::DType::kFp8Only,
149150
GemmRewriterOptions::DType::kNonFp8Only}) {
150151
auto gemm_rewriter = std::make_unique<GemmRewriter>(
151-
device_description.gpu_compute_capability(),
152-
device_description.runtime_version(), GemmRewriterOptions{dtype});
152+
device_description->gpu_compute_capability(),
153+
device_description->runtime_version(), GemmRewriterOptions{dtype});
153154
pipeline->AddPass(std::move(gemm_rewriter));
154155
}
155156
return pipeline;
156157
}
157158

159+
std::unique_ptr<HloPassPipeline> GetCustomKernelRewriterPipeline(
160+
const se::DeviceDescription* device_description) {
161+
auto pipeline =
162+
std::make_unique<HloPassPipeline>("custom_kernel_rewriter_pipeline");
163+
pipeline->AddPass(
164+
std::make_unique<CustomKernelFusionRewriter>(device_description));
165+
return pipeline;
166+
}
167+
158168
using AutoTuneCacheKeyCount = absl::flat_hash_map<AutotuneCacheKey, uint64_t>;
159169

160170
using KeysAndInstructions =
@@ -1698,8 +1708,15 @@ absl::StatusOr<bool> GemmFusionAutotuner::RunViaNewInfra(
16981708
backends.push_back(std::make_unique<FissionBackend>(
16991709
&debug_options, compiler.get(), target_config.get(),
17001710
std::make_unique<CublasBackend>(stream_exec, &debug_options,
1701-
compiler.get(), target_config.get()),
1702-
GetCublasRewriterPipeline(target_config->device_description),
1711+
compiler.get(), target_config.get(),
1712+
/*fp8_lt_fallback=*/true),
1713+
GetCublasRewriterPipeline(&target_config->device_description),
1714+
mlir_context_));
1715+
backends.push_back(std::make_unique<FissionBackend>(
1716+
&debug_options, compiler.get(), target_config.get(),
1717+
std::make_unique<CustomKernelBackend>(
1718+
stream_exec, &debug_options, compiler.get(), target_config.get()),
1719+
GetCustomKernelRewriterPipeline(&target_config->device_description),
17031720
mlir_context_));
17041721
auto should_autotune = [](const HloInstruction& instruction) -> bool {
17051722
if (instruction.opcode() != HloOpcode::kFusion) {
@@ -1708,8 +1725,12 @@ absl::StatusOr<bool> GemmFusionAutotuner::RunViaNewInfra(
17081725
auto gpu_config = instruction.backend_config<GpuBackendConfig>();
17091726
const FusionBackendConfig& backend_config =
17101727
gpu_config->fusion_backend_config();
1711-
if (backend_config.kind() == kTritonGemmFusionKind ||
1712-
backend_config.kind() == kCuDnnFusionKind) {
1728+
bool is_unassigned_triton =
1729+
backend_config.kind() == kTritonGemmFusionKind &&
1730+
!backend_config.has_triton_gemm_config();
1731+
bool is_unassigned_custom = backend_config.kind() == kCustomFusionKind &&
1732+
!backend_config.has_custom_fusion_config();
1733+
if (is_unassigned_triton || is_unassigned_custom) {
17131734
return true;
17141735
}
17151736
return false;

0 commit comments

Comments
 (0)