Skip to content

Commit 7bbb166

Browse files
[XLA:GPU] Run GpuCostModelStatsCollection prior scheduling.
PiperOrigin-RevId: 707475368
1 parent b6c3124 commit 7bbb166

File tree

1 file changed

+9
-13
lines changed

1 file changed

+9
-13
lines changed

xla/service/gpu/gpu_compiler.cc

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,19 +1043,6 @@ absl::Status RunFusionPasses(HloModule* hlo_module,
10431043
.Run(hlo_module)
10441044
.status());
10451045

1046-
if (hlo_module->config().debug_options().xla_gpu_collect_cost_model_stats()) {
1047-
GpuHloCostAnalysis::Options cost_analysis_options{
1048-
shape_size_fn,
1049-
/*per_second_rates=*/{},
1050-
/*min_latencies_seconds=*/{},
1051-
/*count_multiple_input_accesses=*/true};
1052-
1053-
HloPassPipeline post_fusion_analysis("post_fusion_analysis");
1054-
post_fusion_analysis.AddPass<GpuCostModelStatsCollection>(
1055-
gpu_device_info, cost_analysis_options);
1056-
TF_RETURN_IF_ERROR(post_fusion_analysis.Run(hlo_module).status());
1057-
}
1058-
10591046
TF_RETURN_IF_ERROR(
10601047
HorizontalFusionPipeline(gpu_device_info).Run(hlo_module).status());
10611048

@@ -2567,6 +2554,15 @@ absl::Status GpuCompiler::RunPreSchedulingPasses(
25672554
const se::DeviceDescription& gpu_device_info) {
25682555
HloPassPipeline pipeline("pre-scheduling-passes");
25692556
pipeline.AddPass<FusionWrapper>(gpu_device_info);
2557+
if (module->config().debug_options().xla_gpu_collect_cost_model_stats()) {
2558+
GpuHloCostAnalysis::Options cost_analysis_options{
2559+
ShapeSizeBytesFunction(),
2560+
/*per_second_rates=*/{},
2561+
/*min_latencies_seconds=*/{},
2562+
/*count_multiple_input_accesses=*/true};
2563+
pipeline.AddPass<GpuCostModelStatsCollection>(gpu_device_info,
2564+
cost_analysis_options);
2565+
}
25702566
return pipeline.Run(module).status();
25712567
}
25722568

0 commit comments

Comments
 (0)