Skip to content

Commit ab8bcb3

Browse files
[XLA:TPU] Reuse same Alias Analysis object in RunMemorySpaceAssignment
PiperOrigin-RevId: 707502636
1 parent 7b8e0d0 commit ab8bcb3

File tree

2 files changed

+38
-36
lines changed

2 files changed

+38
-36
lines changed

xla/service/memory_space_assignment/memory_space_assignment.cc

Lines changed: 33 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -286,11 +286,10 @@ void TransformAllocationSequenceToSpill(AllocationSequence& allocations,
286286
} // namespace
287287

288288
absl::StatusOr<MemorySpaceAssignment::AsyncCopyStats>
289-
MemorySpaceAssignment::CalculateAsyncCopyStats() const {
289+
MemorySpaceAssignment::CalculateAsyncCopyStats(
290+
const HloDataflowAnalysis& dataflow_analysis) const {
290291
AsyncCopyStats stats;
291292
int64_t current_copies = 0;
292-
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
293-
HloDataflowAnalysis::Run(*module_));
294293
for (const HloComputation* computation :
295294
module_->MakeNonfusionComputations()) {
296295
for (HloInstruction* instruction : computation->instructions()) {
@@ -305,7 +304,7 @@ MemorySpaceAssignment::CalculateAsyncCopyStats() const {
305304
HloOpcode::kSlice)) {
306305
current_copies--;
307306
int64_t size =
308-
options_.size_fn(dataflow_analysis->GetUniqueValueAt(instruction));
307+
options_.size_fn(dataflow_analysis.GetUniqueValueAt(instruction));
309308
if (instruction->shape().layout().memory_space() ==
310309
options_.alternate_memory_space) {
311310
++stats.num_prefetches;
@@ -388,11 +387,13 @@ MemorySpaceAssignment::RunMemorySpaceAssignment(
388387
if (options_.cost_analysis) {
389388
runtime_simulator.emplace(options_.cost_analysis,
390389
options_.alternate_memory_space);
391-
float estimated_time =
392-
runtime_simulator->SimulateElapsedTimeWithoutAsyncCopyLikes(
393-
hlo_live_range, allocations_);
394-
VLOG(1) << "Estimated elapsed time without async copies (sec): "
395-
<< estimated_time;
390+
if (VLOG_IS_ON(1)) {
391+
float estimated_time =
392+
runtime_simulator->SimulateElapsedTimeWithoutAsyncCopyLikes(
393+
hlo_live_range, allocations_);
394+
LOG(INFO) << "Estimated elapsed time without async copies (sec): "
395+
<< estimated_time;
396+
}
396397
}
397398

398399
TF_RETURN_IF_ERROR(Process(hlo_live_range));
@@ -409,35 +410,34 @@ MemorySpaceAssignment::RunMemorySpaceAssignment(
409410
ScheduleAsynchronousCopies();
410411
TF_RETURN_IF_ERROR(SimplifyGraph());
411412
TF_RETURN_IF_ERROR(FixSchedule());
412-
TF_RETURN_IF_ERROR(ExportAndColorBuffers());
413+
TF_ASSIGN_OR_RETURN(auto alias, HloAliasAnalysis::Run(module_));
414+
TF_RETURN_IF_ERROR(ExportAndColorBuffers(*alias));
413415
std::vector<int64_t> alt_mem_bytes_occupied;
414416
// alt_mem_bytes_occupied is used for logging in the RuntimeSimulator below.
415417
// We only populate it in VerifyAndExportHeapSimulatorTrace if the
416418
// RuntimeSimulator is present.
417419
TF_RETURN_IF_ERROR(VerifyAndExportHeapSimulatorTrace(
420+
*alias,
418421
runtime_simulator.has_value() ? &alt_mem_bytes_occupied : nullptr));
419-
if (runtime_simulator.has_value()) {
420-
float estimated_time = runtime_simulator->SimulateElapsedTime(
421-
module_, allocations_, &alt_mem_bytes_occupied);
422-
VLOG(1) << "Estimated elapsed time with async copies (sec): "
423-
<< estimated_time;
424-
}
425422

426423
if (VLOG_IS_ON(3)) {
427424
LOG(INFO) << "Module after memory space assignment: ";
428425
XLA_LOG_LINES(INFO, module_->ToString());
429426
}
430427
TF_CHECK_OK(module_->schedule().Verify());
431-
TF_ASSIGN_OR_RETURN(AsyncCopyStats stats, CalculateAsyncCopyStats());
432-
VLOG(1) << "Maximum number of outstanding async copies/slices: "
433-
<< stats.max_outstanding_async_copies;
434-
VLOG(1) << "Number of prefetches: " << stats.num_prefetches
435-
<< ", in bytes: " << stats.prefetch_bytes;
436-
VLOG(1) << "Number of sliced prefetches: " << stats.num_sliced_prefetches
437-
<< ", consuming number of slices: "
438-
<< stats.num_sliced_prefetch_slices;
439-
VLOG(1) << "Number of evictions: " << stats.num_evictions
440-
<< ", in bytes: " << stats.eviction_bytes;
428+
if (VLOG_IS_ON(1)) {
429+
TF_ASSIGN_OR_RETURN(AsyncCopyStats stats,
430+
CalculateAsyncCopyStats(alias->dataflow_analysis()));
431+
LOG(INFO) << "Maximum number of outstanding async copies/slices: "
432+
<< stats.max_outstanding_async_copies;
433+
LOG(INFO) << "Number of prefetches: " << stats.num_prefetches
434+
<< ", in bytes: " << stats.prefetch_bytes;
435+
LOG(INFO) << "Number of sliced prefetches: " << stats.num_sliced_prefetches
436+
<< ", consuming number of slices: "
437+
<< stats.num_sliced_prefetch_slices;
438+
LOG(INFO) << "Number of evictions: " << stats.num_evictions
439+
<< ", in bytes: " << stats.eviction_bytes;
440+
}
441441

442442
return std::move(preset_assignments_);
443443
}
@@ -539,15 +539,15 @@ absl::Status MemorySpaceAssignment::Process(
539539
return absl::OkStatus();
540540
}
541541

542-
absl::Status MemorySpaceAssignment::ExportAndColorBuffers() {
542+
absl::Status MemorySpaceAssignment::ExportAndColorBuffers(
543+
const HloAliasAnalysis& alias_analysis) {
543544
VLOG(1) << "Exporting buffers...";
544-
TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module_));
545545
absl::flat_hash_map<int64_t, int64_t> seen_buffer_offsets;
546546
VLOG(3) << "Exported alternate memory allocations:";
547547
for (const auto& position_and_chunk : alternate_memory_assignments_) {
548548
const HloPosition& defining_position = position_and_chunk.first;
549549
const HeapSimulator::Chunk& chunk = position_and_chunk.second;
550-
const HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(
550+
const HloBuffer& buffer = alias_analysis.GetUniqueBufferAt(
551551
defining_position.instruction, defining_position.index);
552552
auto seen_buffer_offset_it = seen_buffer_offsets.find(buffer.id());
553553
if (seen_buffer_offset_it != seen_buffer_offsets.end()) {
@@ -589,7 +589,7 @@ absl::Status MemorySpaceAssignment::ExportAndColorBuffers() {
589589
for (const auto& defining_position_and_chunk :
590590
preset_assignments_->chunks()) {
591591
const HloPosition& defining_position = defining_position_and_chunk.first;
592-
for (auto& buffer : alias_analysis->ComputeBuffersAt(
592+
for (auto& buffer : alias_analysis.ComputeBuffersAt(
593593
defining_position.instruction, defining_position.index)) {
594594
for (auto& value : buffer->values()) {
595595
for (auto& position : value->positions()) {
@@ -1049,12 +1049,11 @@ absl::Status MemorySpaceAssignment::FixSchedule() {
10491049
}
10501050

10511051
absl::Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace(
1052+
const HloAliasAnalysis& alias_analysis,
10521053
std::vector<int64_t>* alt_mem_bytes_occupied) {
10531054
VLOG(1) << "Verifying...";
1054-
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
1055-
HloAliasAnalysis::Run(module_));
10561055
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
1057-
HloLiveRange::Run(module_->schedule(), *alias_analysis,
1056+
HloLiveRange::Run(module_->schedule(), alias_analysis,
10581057
module_->entry_computation()));
10591058

10601059
BufferIntervalTree interval_tree;
@@ -1120,7 +1119,7 @@ absl::Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace(
11201119
const HloPosition& position = position_and_chunk.first;
11211120
const HeapSimulator::Chunk& chunk = position_and_chunk.second;
11221121
const HloBuffer& buffer =
1123-
alias_analysis->GetUniqueBufferAt(position.instruction, position.index);
1122+
alias_analysis.GetUniqueBufferAt(position.instruction, position.index);
11241123
CHECK(!seen_buffers.contains(buffer.id()))
11251124
<< "Multiple preset assignments for the same buffer: "
11261125
<< buffer.ToString() << ", pos: " << position.ToString()

xla/service/memory_space_assignment/memory_space_assignment.h

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -190,6 +190,7 @@ Useful logging and error messages
190190
#include "absl/status/statusor.h"
191191
#include "absl/types/span.h"
192192
#include "xla/hlo/analysis/hlo_alias_analysis.h"
193+
#include "xla/hlo/analysis/hlo_dataflow_analysis.h"
193194
#include "xla/hlo/ir/hlo_instruction.h"
194195
#include "xla/hlo/utils/hlo_live_range.h"
195196
#include "xla/service/buffer_value.h"
@@ -305,7 +306,8 @@ class MemorySpaceAssignment {
305306
const HloAliasAnalysis& alias_analysis, const Options& options);
306307

307308
// Calculates asynchronous copy statistics.
308-
absl::StatusOr<AsyncCopyStats> CalculateAsyncCopyStats() const;
309+
absl::StatusOr<AsyncCopyStats> CalculateAsyncCopyStats(
310+
const HloDataflowAnalysis& dataflow_analysis) const;
309311

310312
// Verify that allocations_ are free of overlapping Allocations in time and
311313
// space. This is a post-processing step called after all allocations have
@@ -318,6 +320,7 @@ class MemorySpaceAssignment {
318320
// If alt_mem_bytes_occupied is not null, it will be populated with the number
319321
// of bytes occupied in the alternate memory space at each instruction time.
320322
absl::Status VerifyAndExportHeapSimulatorTrace(
323+
const HloAliasAnalysis& alias_analysis,
321324
std::vector<int64_t>* alt_mem_bytes_occupied = nullptr);
322325

323326
protected:
@@ -372,7 +375,7 @@ class MemorySpaceAssignment {
372375

373376
// Export the alternate memory assignments to the PresetAssignments and color
374377
// the HLO graph with the determined memory spaces.
375-
absl::Status ExportAndColorBuffers();
378+
absl::Status ExportAndColorBuffers(const HloAliasAnalysis& alias_analysis);
376379

377380
// Schedules asynchronous copies and ensures that the CopyStarts and their
378381
// corresponding CopyDones follow the same order.

0 commit comments

Comments
 (0)