Skip to content

Commit d24c763

Browse files
authored
【Allocator】Update stategy of Tryalloc and AllocatorVisitor (PaddlePaddle#76523)
* update_stategy * update strategy
1 parent 20d9626 commit d24c763

File tree

9 files changed

+89
-32
lines changed

9 files changed

+89
-32
lines changed

paddle/common/flags.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2318,7 +2318,7 @@ PHI_DEFINE_EXPORTED_bool(use_accuracy_compatible_kernel,
23182318
/**
23192319
* Allocator Compact related FLAG
23202320
* Name: FLAGS_enable_compact_mem
2321-
* Since Version: 3.2.2
2321+
* Since Version: 3.3
23222322
* Value Range: bool, default=false
23232323
* Example:
23242324
* Note: whether start compact memory.
@@ -2329,7 +2329,7 @@ PHI_DEFINE_EXPORTED_bool(enable_compact_mem,
23292329
/**
23302330
* Allocator Compact related FLAG
23312331
* Name: FLAGS_max_reserved_threshold_in_gb
2332-
* Since Version: 3.2.2
2332+
* Since Version: 3.3
23332333
* Value Range: int64, default=70
23342334
* Example:
23352335
* Note: Threshold (GB) used in compact memory. Only reserved_mem greater than
@@ -2344,7 +2344,7 @@ PHI_DEFINE_EXPORTED_int64(
23442344
/**
23452345
* Allocator Compact related FLAG
23462346
* Name: FLAGS_cur_allocated_threshold_in_gb
2347-
* Since Version: 3.2.2
2347+
* Since Version: 3.3
23482348
* Value Range: int64, default=70
23492349
* Example:
23502350
* Note: Threshold (GB) used in compact memory. Only reserved_mem greater than
@@ -2359,7 +2359,7 @@ PHI_DEFINE_EXPORTED_int64(
23592359
/**
23602360
* Allocator Compact related FLAG
23612361
* Name: FLAGS_try_allocate
2362-
* Since Version: 3.2.2
2362+
* Since Version: 3.3
23632363
* Value Range: bool, default=false
23642364
* Example:
23652365
* Note: whether start compact memory.

paddle/fluid/pybind/eager_functions.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,7 @@ PyObject* eager_api_run_custom_op(PyObject* self,
543543
PyObject* kwargs) {
544544
EAGER_TRY
545545
FLAGS_tensor_operants_mode = "phi";
546-
bool old_flag = FLAGS_enable_compact_mem;
546+
bool compact_flag_bak = FLAGS_enable_compact_mem;
547547
FLAGS_enable_compact_mem = false;
548548
if (paddle::OperantsManager::Instance().phi_operants.get() == nullptr) {
549549
paddle::OperantsManager::Instance().phi_operants =
@@ -881,7 +881,7 @@ PyObject* eager_api_run_custom_op(PyObject* self,
881881
if (FLAGS_check_cuda_error) [[unlikely]] {
882882
egr::CUDAErrorCheck("eager_api_run_custom_op " + op_type + " finish");
883883
}
884-
FLAGS_enable_compact_mem = old_flag;
884+
FLAGS_enable_compact_mem = compact_flag_bak;
885885
return ToPyObject(*ctx.AllMutableOutput());
886886
EAGER_CATCH_AND_THROW_RETURN_NULL
887887
}

paddle/phi/api/lib/api_gen_utils.cc

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -870,32 +870,45 @@ void CheckAndDoCompact(const std::vector<phi::MetaTensor*>& meta_tensors,
870870
auto NeedCompact = [&](const std::vector<phi::MetaTensor*>& meta_tensors) {
871871
if (max_reserved < FLAGS_max_reserved_threshold_in_gb << 30) return false;
872872
if (cur_allocated < FLAGS_cur_allocated_threshold_in_gb << 30) return false;
873-
const auto [max_free_size, total_free_size] =
873+
const auto [max_free_size, large_N_free_size] =
874874
paddle::memory::VmmMaxFreeSize(phi::GPUPlace(current_device_id),
875875
meta_tensors.size());
876876
const auto& [req_total_size, size_vec] = CalTensorSize(meta_tensors);
877+
VLOG(10) << "run api: " << api << "req_total_size: " << req_total_size
878+
<< ", max_free_size: " << max_free_size
879+
<< ", large_N_free_size: " << large_N_free_size
880+
<< ", max_reserved: " << max_reserved
881+
<< ", max_allocated: " << max_allocated
882+
<< ", cur_allocated: " << cur_allocated;
877883
if (req_total_size < max_free_size) return false;
878-
if (req_total_size > total_free_size) {
884+
if (req_total_size > large_N_free_size) {
879885
VLOG(1) << "Need Compact req_total_size: " << req_total_size
880-
<< ", total_free_size: " << total_free_size
881-
<< ", max_free_size: " << max_free_size;
886+
<< ", large_N_free_size: " << large_N_free_size
887+
<< ", max_free_size: " << max_free_size
888+
<< ", max_reserved: " << max_reserved
889+
<< ", max_allocated: " << max_allocated
890+
<< ", cur_allocated: " << cur_allocated;
882891
return true;
883892
}
884893
if (FLAGS_try_allocate) {
885894
auto alloc_succ = paddle::memory::TryAllocBatch(
886895
phi::GPUPlace(current_device_id), size_vec);
887-
VLOG(1) << "TryAllocBatch ret: " << !alloc_succ
896+
VLOG(1) << "TryAllocBatch ret: " << alloc_succ
888897
<< ", req_total_size: " << req_total_size
889-
<< ", total_free_size: " << total_free_size
890-
<< ", max_free_size: " << max_free_size;
898+
<< ", large_N_free_size: " << large_N_free_size
899+
<< ", max_free_size: " << max_free_size
900+
<< ", max_reserved: " << max_reserved
901+
<< ", max_allocated: " << max_allocated
902+
<< ", cur_allocated: " << cur_allocated;
891903
return !alloc_succ;
892904
}
893905
return false;
894906
};
895907

896908
if (NeedCompact(meta_tensors)) {
897909
VLOG(1) << "Before Compact max_reserved: " << max_reserved / divisor
898-
<< ", max_allocated: " << max_allocated / divisor;
910+
<< "GB, max_allocated: " << max_allocated / divisor
911+
<< "GB, cur_allocated: " << cur_allocated / divisor << "GB";
899912
paddle::memory::Compact(phi::GPUPlace(current_device_id));
900913
}
901914
#endif

paddle/phi/core/memory/allocation/allocator_facade.cc

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1013,11 +1013,29 @@ class AllocatorFacadePrivate {
10131013
val = 0;
10141014
}
10151015

1016-
if (val > 0 && FLAGS_use_virtual_memory_auto_growth) {
1016+
if (val > 0 && FLAGS_use_virtual_memory_auto_growth &&
1017+
!FLAGS_use_multi_scale_virtual_memory_auto_growth) {
10171018
auto cuda_allocator = std::make_shared<CUDAVirtualMemAllocator>(p);
10181019
cuda_allocators_[p][stream] =
10191020
std::make_shared<VirtualMemoryAutoGrowthBestFitAllocator>(
10201021
cuda_allocator, platform::GpuMinChunkSize(), p);
1022+
} else if (val > 0 && FLAGS_use_multi_scale_virtual_memory_auto_growth) {
1023+
std::cout << "enter init branch" << std::endl;
1024+
auto cuda_allocator_small = std::make_shared<CUDAVirtualMemAllocator>(p);
1025+
auto cuda_allocator_large = std::make_shared<CUDAVirtualMemAllocator>(p);
1026+
auto vmm_allocator_small =
1027+
std::make_shared<VirtualMemoryAutoGrowthBestFitAllocator>(
1028+
cuda_allocator_small, platform::GpuMinChunkSize(), p);
1029+
auto vmm_allocator_large =
1030+
std::make_shared<VirtualMemoryAutoGrowthBestFitAllocator>(
1031+
cuda_allocator_large, platform::GpuMinChunkSize(), p);
1032+
1033+
cuda_allocators_[p][stream] = std::make_shared<
1034+
VirtualMemoryAutoGrowthBestFitMultiScalePoolAllocator>(
1035+
vmm_allocator_small,
1036+
vmm_allocator_large,
1037+
platform::GpuMinChunkSize(),
1038+
p);
10211039
} else {
10221040
auto cuda_allocator = CreateCUDAAllocator(p);
10231041
if (FLAGS_use_auto_growth_v2) {

paddle/phi/core/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -363,10 +363,11 @@ bool VirtualMemoryAutoGrowthBestFitAllocator::TryAllocateBatch(
363363

364364
std::lock_guard<SpinLock> guard(spinlock_);
365365

366-
// copy free_blocks_ to shadow_blocks_
366+
// copy large N free_blocks_ to shadow_blocks_.
367367
std::map<std::pair<size_t, void *>, size_t> shadow_blocks;
368-
for (const auto &pair : free_blocks_) {
369-
shadow_blocks.emplace(pair.first, pair.first.first);
368+
auto it = free_blocks_.rbegin();
369+
for (int i = 0; i < sizes.size() && it != free_blocks_.rend(); ++i, ++it) {
370+
shadow_blocks.emplace(it->first, it->first.first);
370371
}
371372
for (size_t size : sizes) {
372373
size_t aligned_size = AlignedSize(size, alignment_);

paddle/phi/core/memory/mem_visitor.cc

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,23 +54,33 @@ void AllocatorVisitor::Visit(
5454
allocator->GetLargeAllocator()->Accept(this);
5555
}
5656

57+
void AllocatorComputeStreamVisitor::Visit(StreamSafeCUDAAllocator* allocator) {
58+
const std::vector<StreamSafeCUDAAllocator*>& allocators =
59+
allocator->GetAllocatorByPlace();
60+
assert(!allocators.empty());
61+
// NOTE(liujinnan): Currently, the Allocator initialization sequence is as
62+
// follows: the compute stream Allocator is initialized at program startup,
63+
// and then, when multiple streams are encountered at runtime, additional
64+
// Allocators are created and added to the end of the `allocator_map_` in
65+
// `StreamSafeCUDAAllocator`. Therefore, we can use the first allocator in
66+
// `allocator_map_` as the compute stream allocator. Although this approach is
67+
// somewhat ugly and may not be robust, it is currently effective.
68+
allocators[0]->GetUnderLyingAllocator()->Accept(this);
69+
}
70+
5771
void FreeMemoryMetricsVisitor::Visit(
5872
VirtualMemoryAutoGrowthBestFitAllocator* allocator) {
5973
auto [large_size, sum_size] =
6074
allocator->SumLargestFreeBlockSizes(nums_blocks_);
6175
large_size_ = std::max(large_size_, large_size);
6276
sum_size_ = std::max(sum_size_, sum_size);
63-
VLOG(1) << "Visit VirtualMemoryAutoGrowthBestFitAllocator large_free_size:"
64-
<< large_size_ << " sum_free_size:" << sum_size_;
6577
}
6678

6779
void TryAllocVisitor::Visit(
6880
VirtualMemoryAutoGrowthBestFitAllocator* allocator) {
6981
// TODO(liujinnan): More detailed handling of multi-stream and MultiScalePool
7082
// scenarios.
7183
is_try_alloc_success_ |= allocator->TryAllocateBatch(sizes_);
72-
VLOG(1) << "Visit VirtualMemoryAutoGrowthBestFitAllocator try_alloc_result:"
73-
<< is_try_alloc_success_;
7484
}
7585

7686
void VMMFreeBlocksInfoVisitor::Visit(

paddle/phi/core/memory/mem_visitor.h

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,16 @@ class AllocatorVisitor : public AllocatorVisitorReqImpl {
8383
};
8484

8585
#ifdef PADDLE_WITH_CUDA
86+
/**
87+
* @brief AllocatorComputeStreamVisitor is a Concrete Visitor class designed to
88+
* only visit compute stream allocators.
89+
*/
90+
class AllocatorComputeStreamVisitor : public AllocatorVisitor {
91+
public:
92+
using AllocatorVisitor::Visit;
93+
void Visit(StreamSafeCUDAAllocator* allocator) override;
94+
};
95+
8696
/**
8797
* @brief FreeMemoryMetricsVisitor is a Concrete Visitor class designed to
8898
* inspect allocators for free memory information.
@@ -92,8 +102,9 @@ class AllocatorVisitor : public AllocatorVisitorReqImpl {
92102
* it provides specialized logic for the
93103
* VirtualMemoryAutoGrowthBestFitAllocator.
94104
*/
95-
class FreeMemoryMetricsVisitor : public AllocatorVisitor {
105+
class FreeMemoryMetricsVisitor : public AllocatorComputeStreamVisitor {
96106
public:
107+
using AllocatorComputeStreamVisitor::Visit;
97108
/**
98109
* @brief Constructor for FreeMemoryMetricsVisitor.
99110
* @param nums_blocks The number of largest free blocks to potentially track
@@ -139,7 +150,9 @@ class FreeMemoryMetricsVisitor : public AllocatorVisitor {
139150
* (typically VirtualMemoryAutoGrowthBestFitAllocator) and record if all
140151
* attempts were successful.
141152
*/
142-
class TryAllocVisitor : public AllocatorVisitor {
153+
class TryAllocVisitor : public AllocatorComputeStreamVisitor {
154+
using AllocatorComputeStreamVisitor::Visit;
155+
143156
public:
144157
/**
145158
* @brief Constructor.
@@ -183,13 +196,10 @@ class TryAllocVisitor : public AllocatorVisitor {
183196
* internal state (the list of free memory blocks) and extract key information
184197
* (size and address) for external analysis or debugging.
185198
*/
186-
class VMMFreeBlocksInfoVisitor : public AllocatorVisitor {
187-
public:
188-
/**
189-
* @brief Default Constructor.
190-
*/
191-
VMMFreeBlocksInfoVisitor() {}
199+
class VMMFreeBlocksInfoVisitor : public AllocatorComputeStreamVisitor {
200+
using AllocatorComputeStreamVisitor::Visit;
192201

202+
public:
193203
/**
194204
* @brief Retrieves the collected information about the free memory blocks.
195205
*

test/cpp/phi/memory/gen_compact_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
1+
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
22
33
Licensed under the Apache License, Version 2.0 (the "License");
44
you may not use this file except in compliance with the License.
@@ -40,7 +40,7 @@ class CheckAndDoCompactTest : public ::testing::Test {
4040
FLAGS_try_allocate = true;
4141
FLAGS_use_multi_scale_virtual_memory_auto_growth = true;
4242
FLAGS_vmm_small_pool_size_in_mb = 2;
43-
FLAGS_v = 4;
43+
FLAGS_v = 10;
4444
}
4545

4646
void TearDown() override { meta_tensors_.clear(); }

test/legacy_test/test_multi_scale_pool_allocator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ def allocate_cmds(self, cmds):
6363
print(
6464
f"reserved = {paddle_reserved2} allocated = {paddle_allocated2} auto growth = {paddle_reserved2 - paddle_reserved1} max_allocated = {paddle_max_allocated} max_reserved = {paddle_max_reserved}"
6565
)
66+
# for multi stream
67+
stream = paddle.device.cuda.Stream()
68+
with paddle.device.cuda.stream_guard(stream):
69+
x = paddle.empty([int(1 * 1024 * 1024 * 1024)], dtype=paddle.uint8)
70+
del x
6671
return params
6772

6873
def test_multi_scale_alloc_free(self):

0 commit comments

Comments
 (0)