@@ -1043,19 +1043,6 @@ absl::Status RunFusionPasses(HloModule* hlo_module,
10431043 .Run (hlo_module)
10441044 .status ());
10451045
1046- if (hlo_module->config ().debug_options ().xla_gpu_collect_cost_model_stats ()) {
1047- GpuHloCostAnalysis::Options cost_analysis_options{
1048- shape_size_fn,
1049- /* per_second_rates=*/ {},
1050- /* min_latencies_seconds=*/ {},
1051- /* count_multiple_input_accesses=*/ true };
1052-
1053- HloPassPipeline post_fusion_analysis (" post_fusion_analysis" );
1054- post_fusion_analysis.AddPass <GpuCostModelStatsCollection>(
1055- gpu_device_info, cost_analysis_options);
1056- TF_RETURN_IF_ERROR (post_fusion_analysis.Run (hlo_module).status ());
1057- }
1058-
10591046 TF_RETURN_IF_ERROR (
10601047 HorizontalFusionPipeline (gpu_device_info).Run (hlo_module).status ());
10611048
@@ -2567,6 +2554,15 @@ absl::Status GpuCompiler::RunPreSchedulingPasses(
25672554 const se::DeviceDescription& gpu_device_info) {
25682555 HloPassPipeline pipeline (" pre-scheduling-passes" );
25692556 pipeline.AddPass <FusionWrapper>(gpu_device_info);
2557+ if (module ->config ().debug_options ().xla_gpu_collect_cost_model_stats ()) {
2558+ GpuHloCostAnalysis::Options cost_analysis_options{
2559+ ShapeSizeBytesFunction (),
2560+ /* per_second_rates=*/ {},
2561+ /* min_latencies_seconds=*/ {},
2562+ /* count_multiple_input_accesses=*/ true };
2563+ pipeline.AddPass <GpuCostModelStatsCollection>(gpu_device_info,
2564+ cost_analysis_options);
2565+ }
25702566 return pipeline.Run (module ).status ();
25712567}
25722568
0 commit comments