@@ -1946,7 +1946,7 @@ index 0aa610fc9..3c4b34ace 100644
1946
1946
MatrixIsColumnMajor(instr, gemm_backend_config));
1947
1947
1948
1948
diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc
1949
- index d0c20aa1c..f975f8122 100644
1949
+ index d0c20aa1c..2030ba0b1 100644
1950
1950
--- a/xla/service/gpu/gpu_compiler.cc
1951
1951
+++ b/xla/service/gpu/gpu_compiler.cc
1952
1952
@@ -268,6 +268,8 @@ limitations under the License.
@@ -2012,8 +2012,14 @@ index d0c20aa1c..f975f8122 100644
2012
2012
TF_RETURN_IF_ERROR(OptimizeHloConvolutionCanonicalization(
2013
2013
hlo_module, gpu_version, dnn_version, options.device_allocator));
2014
2014
2015
- @@ -1414,7 +1425,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
2016
- pipeline.AddPass<GemmRewriter>(gpu_version, /*f8_rewrite=*/true);
2015
+ @@ -1411,10 +1422,13 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
2016
+
2017
+ // Rewrite FP8 GEMMs ahead of Triton which currently lacks support for FP8
2018
+ // and may rewrite quantized FP8 GEMMs as higher-precision GEMMs.
2019
+ - pipeline.AddPass<GemmRewriter>(gpu_version, /*f8_rewrite=*/true);
2020
+ + // pipeline.AddPass<GemmRewriter>(gpu_version, /*f8_rewrite=*/true);
2021
+ + // SYCL doesn't support fp8 gemm yet.
2022
+ + pipeline.AddPass<GemmRewriter>(gpu_version, /*f8_rewrite=*/false);
2017
2023
if (debug_options.xla_gpu_enable_triton_gemm() && cuda_cc != nullptr &&
2018
2024
cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) {
2019
2025
- pipeline.AddPass<GemmFusion>(gpu_version);
@@ -2022,7 +2028,7 @@ index d0c20aa1c..f975f8122 100644
2022
2028
}
2023
2029
// Rewrite non-FP8 GEMMs.
2024
2030
pipeline.AddPass<GemmRewriter>(gpu_version, /*f8_rewrite=*/false);
2025
- @@ -1436,8 +1448 ,9 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
2031
+ @@ -1436,8 +1450 ,9 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
2026
2032
if (debug_options.xla_gpu_enable_triton_softmax_fusion() &&
2027
2033
cuda_cc != nullptr &&
2028
2034
cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) {
@@ -2034,7 +2040,7 @@ index d0c20aa1c..f975f8122 100644
2034
2040
}
2035
2041
2036
2042
pipeline.AddPass<ReductionDimensionGrouper>();
2037
- @@ -1770,6 +1783 ,11 @@ GpuCompiler::CompileSingleModule(const HloModuleConfig& module_config,
2043
+ @@ -1770,6 +1785 ,11 @@ GpuCompiler::CompileSingleModule(const HloModuleConfig& module_config,
2038
2044
2039
2045
// Write PTX to IR dump directory, if IR dumping was requested.
2040
2046
if (should_dump) {
@@ -2046,31 +2052,31 @@ index d0c20aa1c..f975f8122 100644
2046
2052
absl::string_view ptx = result.asm_text;
2047
2053
if (debug_module) {
2048
2054
DumpToFileInDirOrStdout(*debug_module, "",
2049
- @@ -2084,6 +2102 ,7 @@ absl::StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
2055
+ @@ -2084,6 +2104 ,7 @@ absl::StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
2050
2056
absl::StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
2051
2057
GpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
2052
2058
const AotCompilationOptions& options) {
2053
2059
+ #if 0
2054
2060
#if GOOGLE_CUDA
2055
2061
CHECK(options.PlatformId() == se::cuda::kCudaPlatformId);
2056
2062
#elif TENSORFLOW_USE_ROCM
2057
- @@ -2137,6 +2156 ,7 @@ GpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
2063
+ @@ -2137,6 +2158 ,7 @@ GpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
2058
2064
}
2059
2065
2060
2066
return std::move(results);
2061
2067
+ #endif
2062
2068
}
2063
2069
2064
2070
HloCostAnalysis::ShapeSizeFunction GpuCompiler::ShapeSizeBytesFunction() const {
2065
- @@ -2148,6 +2168 ,7 @@ HloCostAnalysis::ShapeSizeFunction GpuCompiler::ShapeSizeBytesFunction() const {
2071
+ @@ -2148,6 +2170 ,7 @@ HloCostAnalysis::ShapeSizeFunction GpuCompiler::ShapeSizeBytesFunction() const {
2066
2072
2067
2073
absl::StatusOr<std::unique_ptr<AotCompilationResult>> GpuCompiler::Export(
2068
2074
Executable* executable) const {
2069
2075
+ #if 0
2070
2076
auto* gpu_executable = tensorflow::down_cast<GpuExecutable*>(executable);
2071
2077
if (!gpu_executable) return Internal("GpuExecutable is null");
2072
2078
2073
- @@ -2155,6 +2176 ,8 @@ absl::StatusOr<std::unique_ptr<AotCompilationResult>> GpuCompiler::Export(
2079
+ @@ -2155,6 +2178 ,8 @@ absl::StatusOr<std::unique_ptr<AotCompilationResult>> GpuCompiler::Export(
2074
2080
&gpu_executable->module(), gpu_executable->buffer_assignment(),
2075
2081
gpu_executable->text(), gpu_executable->binary(),
2076
2082
gpu_executable->dnn_compiled_graphs());
@@ -2079,7 +2085,7 @@ index d0c20aa1c..f975f8122 100644
2079
2085
}
2080
2086
2081
2087
absl::Status GpuCompiler::RunPostSchedulingPipelines(
2082
- @@ -2215,13 +2238 ,18 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines(
2088
+ @@ -2215,13 +2240 ,18 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines(
2083
2089
auto driver_version = se::gpu::GpuDriver::GetDriverVersion();
2084
2090
#if GOOGLE_CUDA
2085
2091
constexpr int toolkit_version = CUDA_VERSION;
0 commit comments