@@ -286,11 +286,10 @@ void TransformAllocationSequenceToSpill(AllocationSequence& allocations,
286286} // namespace
287287
288288absl::StatusOr<MemorySpaceAssignment::AsyncCopyStats>
289- MemorySpaceAssignment::CalculateAsyncCopyStats () const {
289+ MemorySpaceAssignment::CalculateAsyncCopyStats (
290+ const HloDataflowAnalysis& dataflow_analysis) const {
290291 AsyncCopyStats stats;
291292 int64_t current_copies = 0 ;
292- TF_ASSIGN_OR_RETURN (std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
293- HloDataflowAnalysis::Run (*module_));
294293 for (const HloComputation* computation :
295294 module_->MakeNonfusionComputations ()) {
296295 for (HloInstruction* instruction : computation->instructions ()) {
@@ -305,7 +304,7 @@ MemorySpaceAssignment::CalculateAsyncCopyStats() const {
305304 HloOpcode::kSlice )) {
306305 current_copies--;
307306 int64_t size =
308- options_.size_fn (dataflow_analysis-> GetUniqueValueAt (instruction));
307+ options_.size_fn (dataflow_analysis. GetUniqueValueAt (instruction));
309308 if (instruction->shape ().layout ().memory_space () ==
310309 options_.alternate_memory_space ) {
311310 ++stats.num_prefetches ;
@@ -388,11 +387,13 @@ MemorySpaceAssignment::RunMemorySpaceAssignment(
388387 if (options_.cost_analysis ) {
389388 runtime_simulator.emplace (options_.cost_analysis ,
390389 options_.alternate_memory_space );
391- float estimated_time =
392- runtime_simulator->SimulateElapsedTimeWithoutAsyncCopyLikes (
393- hlo_live_range, allocations_);
394- VLOG (1 ) << " Estimated elapsed time without async copies (sec): "
395- << estimated_time;
390+ if (VLOG_IS_ON (1 )) {
391+ float estimated_time =
392+ runtime_simulator->SimulateElapsedTimeWithoutAsyncCopyLikes (
393+ hlo_live_range, allocations_);
394+ LOG (INFO) << " Estimated elapsed time without async copies (sec): "
395+ << estimated_time;
396+ }
396397 }
397398
398399 TF_RETURN_IF_ERROR (Process (hlo_live_range));
@@ -409,35 +410,34 @@ MemorySpaceAssignment::RunMemorySpaceAssignment(
409410 ScheduleAsynchronousCopies ();
410411 TF_RETURN_IF_ERROR (SimplifyGraph ());
411412 TF_RETURN_IF_ERROR (FixSchedule ());
412- TF_RETURN_IF_ERROR (ExportAndColorBuffers ());
413+ TF_ASSIGN_OR_RETURN (auto alias, HloAliasAnalysis::Run (module_));
414+ TF_RETURN_IF_ERROR (ExportAndColorBuffers (*alias));
413415 std::vector<int64_t > alt_mem_bytes_occupied;
414416 // alt_mem_bytes_occupied is used for logging in the RuntimeSimulator below.
415417 // We only populate it in VerifyAndExportHeapSimulatorTrace if the
416418 // RuntimeSimulator is present.
417419 TF_RETURN_IF_ERROR (VerifyAndExportHeapSimulatorTrace (
420+ *alias,
418421 runtime_simulator.has_value () ? &alt_mem_bytes_occupied : nullptr ));
419- if (runtime_simulator.has_value ()) {
420- float estimated_time = runtime_simulator->SimulateElapsedTime (
421- module_, allocations_, &alt_mem_bytes_occupied);
422- VLOG (1 ) << " Estimated elapsed time with async copies (sec): "
423- << estimated_time;
424- }
425422
426423 if (VLOG_IS_ON (3 )) {
427424 LOG (INFO) << " Module after memory space assignment: " ;
428425 XLA_LOG_LINES (INFO, module_->ToString ());
429426 }
430427 TF_CHECK_OK (module_->schedule ().Verify ());
431- TF_ASSIGN_OR_RETURN (AsyncCopyStats stats, CalculateAsyncCopyStats ());
432- VLOG (1 ) << " Maximum number of outstanding async copies/slices: "
433- << stats.max_outstanding_async_copies ;
434- VLOG (1 ) << " Number of prefetches: " << stats.num_prefetches
435- << " , in bytes: " << stats.prefetch_bytes ;
436- VLOG (1 ) << " Number of sliced prefetches: " << stats.num_sliced_prefetches
437- << " , consuming number of slices: "
438- << stats.num_sliced_prefetch_slices ;
439- VLOG (1 ) << " Number of evictions: " << stats.num_evictions
440- << " , in bytes: " << stats.eviction_bytes ;
428+ if (VLOG_IS_ON (1 )) {
429+ TF_ASSIGN_OR_RETURN (AsyncCopyStats stats,
430+ CalculateAsyncCopyStats (alias->dataflow_analysis ()));
431+ LOG (INFO) << " Maximum number of outstanding async copies/slices: "
432+ << stats.max_outstanding_async_copies ;
433+ LOG (INFO) << " Number of prefetches: " << stats.num_prefetches
434+ << " , in bytes: " << stats.prefetch_bytes ;
435+ LOG (INFO) << " Number of sliced prefetches: " << stats.num_sliced_prefetches
436+ << " , consuming number of slices: "
437+ << stats.num_sliced_prefetch_slices ;
438+ LOG (INFO) << " Number of evictions: " << stats.num_evictions
439+ << " , in bytes: " << stats.eviction_bytes ;
440+ }
441441
442442 return std::move (preset_assignments_);
443443}
@@ -539,15 +539,15 @@ absl::Status MemorySpaceAssignment::Process(
539539 return absl::OkStatus ();
540540}
541541
542- absl::Status MemorySpaceAssignment::ExportAndColorBuffers () {
542+ absl::Status MemorySpaceAssignment::ExportAndColorBuffers (
543+ const HloAliasAnalysis& alias_analysis) {
543544 VLOG (1 ) << " Exporting buffers..." ;
544- TF_ASSIGN_OR_RETURN (auto alias_analysis, HloAliasAnalysis::Run (module_));
545545 absl::flat_hash_map<int64_t , int64_t > seen_buffer_offsets;
546546 VLOG (3 ) << " Exported alternate memory allocations:" ;
547547 for (const auto & position_and_chunk : alternate_memory_assignments_) {
548548 const HloPosition& defining_position = position_and_chunk.first ;
549549 const HeapSimulator::Chunk& chunk = position_and_chunk.second ;
550- const HloBuffer& buffer = alias_analysis-> GetUniqueBufferAt (
550+ const HloBuffer& buffer = alias_analysis. GetUniqueBufferAt (
551551 defining_position.instruction , defining_position.index );
552552 auto seen_buffer_offset_it = seen_buffer_offsets.find (buffer.id ());
553553 if (seen_buffer_offset_it != seen_buffer_offsets.end ()) {
@@ -589,7 +589,7 @@ absl::Status MemorySpaceAssignment::ExportAndColorBuffers() {
589589 for (const auto & defining_position_and_chunk :
590590 preset_assignments_->chunks ()) {
591591 const HloPosition& defining_position = defining_position_and_chunk.first ;
592- for (auto & buffer : alias_analysis-> ComputeBuffersAt (
592+ for (auto & buffer : alias_analysis. ComputeBuffersAt (
593593 defining_position.instruction , defining_position.index )) {
594594 for (auto & value : buffer->values ()) {
595595 for (auto & position : value->positions ()) {
@@ -1049,12 +1049,11 @@ absl::Status MemorySpaceAssignment::FixSchedule() {
10491049}
10501050
10511051absl::Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace (
1052+ const HloAliasAnalysis& alias_analysis,
10521053 std::vector<int64_t >* alt_mem_bytes_occupied) {
10531054 VLOG (1 ) << " Verifying..." ;
1054- TF_ASSIGN_OR_RETURN (std::unique_ptr<HloAliasAnalysis> alias_analysis,
1055- HloAliasAnalysis::Run (module_));
10561055 TF_ASSIGN_OR_RETURN (std::unique_ptr<HloLiveRange> hlo_live_range,
1057- HloLiveRange::Run (module_->schedule (), * alias_analysis,
1056+ HloLiveRange::Run (module_->schedule (), alias_analysis,
10581057 module_->entry_computation ()));
10591058
10601059 BufferIntervalTree interval_tree;
@@ -1120,7 +1119,7 @@ absl::Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace(
11201119 const HloPosition& position = position_and_chunk.first ;
11211120 const HeapSimulator::Chunk& chunk = position_and_chunk.second ;
11221121 const HloBuffer& buffer =
1123- alias_analysis-> GetUniqueBufferAt (position.instruction , position.index );
1122+ alias_analysis. GetUniqueBufferAt (position.instruction , position.index );
11241123 CHECK (!seen_buffers.contains (buffer.id ()))
11251124 << " Multiple preset assignments for the same buffer: "
11261125 << buffer.ToString () << " , pos: " << position.ToString ()
0 commit comments