Skip to content

Commit f4716be

Browse files
author
lingzhi98
authored
Disable fp8 gemm (#391)
1 parent 584c780 commit f4716be

File tree

1 file changed

+16
-10
lines changed

1 file changed

+16
-10
lines changed

third_party/openxla.patch

Lines changed: 16 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1946,7 +1946,7 @@ index 0aa610fc9..3c4b34ace 100644
19461946
MatrixIsColumnMajor(instr, gemm_backend_config));
19471947

19481948
diff --git a/xla/service/gpu/gpu_compiler.cc b/xla/service/gpu/gpu_compiler.cc
1949-
index d0c20aa1c..f975f8122 100644
1949+
index d0c20aa1c..2030ba0b1 100644
19501950
--- a/xla/service/gpu/gpu_compiler.cc
19511951
+++ b/xla/service/gpu/gpu_compiler.cc
19521952
@@ -268,6 +268,8 @@ limitations under the License.
@@ -2012,8 +2012,14 @@ index d0c20aa1c..f975f8122 100644
20122012
TF_RETURN_IF_ERROR(OptimizeHloConvolutionCanonicalization(
20132013
hlo_module, gpu_version, dnn_version, options.device_allocator));
20142014

2015-
@@ -1414,7 +1425,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
2016-
pipeline.AddPass<GemmRewriter>(gpu_version, /*f8_rewrite=*/true);
2015+
@@ -1411,10 +1422,13 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
2016+
2017+
// Rewrite FP8 GEMMs ahead of Triton which currently lacks support for FP8
2018+
// and may rewrite quantized FP8 GEMMs as higher-precision GEMMs.
2019+
- pipeline.AddPass<GemmRewriter>(gpu_version, /*f8_rewrite=*/true);
2020+
+ // pipeline.AddPass<GemmRewriter>(gpu_version, /*f8_rewrite=*/true);
2021+
+ // SYCL doesn't support fp8 gemm yet.
2022+
+ pipeline.AddPass<GemmRewriter>(gpu_version, /*f8_rewrite=*/false);
20172023
if (debug_options.xla_gpu_enable_triton_gemm() && cuda_cc != nullptr &&
20182024
cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) {
20192025
- pipeline.AddPass<GemmFusion>(gpu_version);
@@ -2022,7 +2028,7 @@ index d0c20aa1c..f975f8122 100644
20222028
}
20232029
// Rewrite non-FP8 GEMMs.
20242030
pipeline.AddPass<GemmRewriter>(gpu_version, /*f8_rewrite=*/false);
2025-
@@ -1436,8 +1448,9 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
2031+
@@ -1436,8 +1450,9 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
20262032
if (debug_options.xla_gpu_enable_triton_softmax_fusion() &&
20272033
cuda_cc != nullptr &&
20282034
cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) {
@@ -2034,7 +2040,7 @@ index d0c20aa1c..f975f8122 100644
20342040
}
20352041

20362042
pipeline.AddPass<ReductionDimensionGrouper>();
2037-
@@ -1770,6 +1783,11 @@ GpuCompiler::CompileSingleModule(const HloModuleConfig& module_config,
2043+
@@ -1770,6 +1785,11 @@ GpuCompiler::CompileSingleModule(const HloModuleConfig& module_config,
20382044

20392045
// Write PTX to IR dump directory, if IR dumping was requested.
20402046
if (should_dump) {
@@ -2046,31 +2052,31 @@ index d0c20aa1c..f975f8122 100644
20462052
absl::string_view ptx = result.asm_text;
20472053
if (debug_module) {
20482054
DumpToFileInDirOrStdout(*debug_module, "",
2049-
@@ -2084,6 +2102,7 @@ absl::StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
2055+
@@ -2084,6 +2104,7 @@ absl::StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
20502056
absl::StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
20512057
GpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
20522058
const AotCompilationOptions& options) {
20532059
+#if 0
20542060
#if GOOGLE_CUDA
20552061
CHECK(options.PlatformId() == se::cuda::kCudaPlatformId);
20562062
#elif TENSORFLOW_USE_ROCM
2057-
@@ -2137,6 +2156,7 @@ GpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
2063+
@@ -2137,6 +2158,7 @@ GpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
20582064
}
20592065

20602066
return std::move(results);
20612067
+#endif
20622068
}
20632069

20642070
HloCostAnalysis::ShapeSizeFunction GpuCompiler::ShapeSizeBytesFunction() const {
2065-
@@ -2148,6 +2168,7 @@ HloCostAnalysis::ShapeSizeFunction GpuCompiler::ShapeSizeBytesFunction() const {
2071+
@@ -2148,6 +2170,7 @@ HloCostAnalysis::ShapeSizeFunction GpuCompiler::ShapeSizeBytesFunction() const {
20662072

20672073
absl::StatusOr<std::unique_ptr<AotCompilationResult>> GpuCompiler::Export(
20682074
Executable* executable) const {
20692075
+#if 0
20702076
auto* gpu_executable = tensorflow::down_cast<GpuExecutable*>(executable);
20712077
if (!gpu_executable) return Internal("GpuExecutable is null");
20722078

2073-
@@ -2155,6 +2176,8 @@ absl::StatusOr<std::unique_ptr<AotCompilationResult>> GpuCompiler::Export(
2079+
@@ -2155,6 +2178,8 @@ absl::StatusOr<std::unique_ptr<AotCompilationResult>> GpuCompiler::Export(
20742080
&gpu_executable->module(), gpu_executable->buffer_assignment(),
20752081
gpu_executable->text(), gpu_executable->binary(),
20762082
gpu_executable->dnn_compiled_graphs());
@@ -2079,7 +2085,7 @@ index d0c20aa1c..f975f8122 100644
20792085
}
20802086

20812087
absl::Status GpuCompiler::RunPostSchedulingPipelines(
2082-
@@ -2215,13 +2238,18 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines(
2088+
@@ -2215,13 +2240,18 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines(
20832089
auto driver_version = se::gpu::GpuDriver::GetDriverVersion();
20842090
#if GOOGLE_CUDA
20852091
constexpr int toolkit_version = CUDA_VERSION;

0 commit comments

Comments (0)