@@ -47,6 +47,7 @@ limitations under the License.
4747#include " xla/autotuning.pb.h"
4848#include " xla/backends/autotuner/codegen_backend.h"
4949#include " xla/backends/gpu/autotuner/cublas.h"
50+ #include " xla/backends/gpu/autotuner/custom_kernel.h"
5051#include " xla/backends/gpu/autotuner/fission_backend.h"
5152#include " xla/backends/gpu/autotuner/triton.h"
5253#include " xla/backends/gpu/runtime/buffer_comparator.h"
@@ -141,20 +142,29 @@ using ProfilingOutput = AutotunerCompileUtil::ProfilingOutput;
141142namespace {
142143
143144std::unique_ptr<HloPassPipeline> GetCublasRewriterPipeline (
144- const se::DeviceDescription& device_description) {
145+ const se::DeviceDescription* device_description) {
145146 auto pipeline = std::make_unique<HloPassPipeline>(" cublas_rewriter_pipeline" );
146147 pipeline->AddPass (std::make_unique<DotAlgorithmRewriter>());
147148 for (GemmRewriterOptions::DType dtype :
148149 {GemmRewriterOptions::DType::kFp8Only ,
149150 GemmRewriterOptions::DType::kNonFp8Only }) {
150151 auto gemm_rewriter = std::make_unique<GemmRewriter>(
151- device_description. gpu_compute_capability (),
152- device_description. runtime_version (), GemmRewriterOptions{dtype});
152+ device_description-> gpu_compute_capability (),
153+ device_description-> runtime_version (), GemmRewriterOptions{dtype});
153154 pipeline->AddPass (std::move (gemm_rewriter));
154155 }
155156 return pipeline;
156157}
157158
159+ std::unique_ptr<HloPassPipeline> GetCustomKernelRewriterPipeline (
160+ const se::DeviceDescription* device_description) {
161+ auto pipeline =
162+ std::make_unique<HloPassPipeline>(" custom_kernel_rewriter_pipeline" );
163+ pipeline->AddPass (
164+ std::make_unique<CustomKernelFusionRewriter>(device_description));
165+ return pipeline;
166+ }
167+
158168using AutoTuneCacheKeyCount = absl::flat_hash_map<AutotuneCacheKey, uint64_t >;
159169
160170using KeysAndInstructions =
@@ -1698,8 +1708,15 @@ absl::StatusOr<bool> GemmFusionAutotuner::RunViaNewInfra(
16981708 backends.push_back (std::make_unique<FissionBackend>(
16991709 &debug_options, compiler.get (), target_config.get (),
17001710 std::make_unique<CublasBackend>(stream_exec, &debug_options,
1701- compiler.get (), target_config.get ()),
1702- GetCublasRewriterPipeline (target_config->device_description ),
1711+ compiler.get (), target_config.get (),
1712+ /* fp8_lt_fallback=*/ true ),
1713+ GetCublasRewriterPipeline (&target_config->device_description ),
1714+ mlir_context_));
1715+ backends.push_back (std::make_unique<FissionBackend>(
1716+ &debug_options, compiler.get (), target_config.get (),
1717+ std::make_unique<CustomKernelBackend>(
1718+ stream_exec, &debug_options, compiler.get (), target_config.get ()),
1719+ GetCustomKernelRewriterPipeline (&target_config->device_description ),
17031720 mlir_context_));
17041721 auto should_autotune = [](const HloInstruction& instruction) -> bool {
17051722 if (instruction.opcode () != HloOpcode::kFusion ) {
@@ -1708,8 +1725,12 @@ absl::StatusOr<bool> GemmFusionAutotuner::RunViaNewInfra(
17081725 auto gpu_config = instruction.backend_config <GpuBackendConfig>();
17091726 const FusionBackendConfig& backend_config =
17101727 gpu_config->fusion_backend_config ();
1711- if (backend_config.kind () == kTritonGemmFusionKind ||
1712- backend_config.kind () == kCuDnnFusionKind ) {
1728+ bool is_unassigned_triton =
1729+ backend_config.kind () == kTritonGemmFusionKind &&
1730+ !backend_config.has_triton_gemm_config ();
1731+ bool is_unassigned_custom = backend_config.kind () == kCustomFusionKind &&
1732+ !backend_config.has_custom_fusion_config ();
1733+ if (is_unassigned_triton || is_unassigned_custom) {
17131734 return true ;
17141735 }
17151736 return false ;
0 commit comments