
Commit 4f01de6

Merge remote-tracking branch 'origin/develop' into feature/ir_inplace_pass
2 parents: 5cab99a + 46a6cac

File tree

8 files changed: +191, -20 lines

paddle/fluid/framework/scope.cc

Lines changed: 1 addition & 5 deletions
@@ -22,11 +22,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/threadpool.h"
 #include "paddle/fluid/string/printf.h"
 
-DEFINE_bool(benchmark, false,
-            "Doing memory benchmark. It will make deleting scope synchronized, "
-            "and add some memory usage logs."
-            "Default cuda is asynchronous device, set to True will"
-            "force op run in synchronous mode.");
+DECLARE_bool(benchmark);
 
 DEFINE_bool(
     eager_delete_scope, true,
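
This hunk pairs with the paddle/fluid/platform/place.cc change below: the definition of the benchmark flag moves there, and scope.cc keeps only a declaration, so a single translation unit owns the flag while several share it. A minimal sketch of the gflags DEFINE/DECLARE split this relies on (the file name and help string here are illustrative, not Paddle's):

// flag_demo.cc -- minimal sketch of the DEFINE_bool/DECLARE_bool split;
// assumes only the gflags library, names are illustrative.
#include <iostream>
#include <gflags/gflags.h>

// Exactly one translation unit owns the flag's storage (place.cc in this
// commit); DEFINE_bool creates the global variable FLAGS_benchmark.
DEFINE_bool(benchmark, false, "Collect memory-usage statistics.");

// Every other translation unit (here, scope.cc and legacy_allocator.cc)
// would instead write:
//   DECLARE_bool(benchmark);  // roughly: extern bool FLAGS_benchmark;

int main(int argc, char *argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);
  std::cout << "benchmark = " << std::boolalpha << FLAGS_benchmark << "\n";
  return 0;
}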

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 2 additions & 1 deletion
@@ -58,7 +58,8 @@ namespace {
 bool IsPersistable(const framework::VarDesc *var) {
   if (var->Persistable() &&
       var->GetType() != framework::proto::VarType::FEED_MINIBATCH &&
-      var->GetType() != framework::proto::VarType::FETCH_LIST) {
+      var->GetType() != framework::proto::VarType::FETCH_LIST &&
+      var->GetType() != framework::proto::VarType::RAW) {
     return true;
   }
   return false;

paddle/fluid/memory/allocation/legacy_allocator.cc

Lines changed: 64 additions & 12 deletions
@@ -35,6 +35,7 @@ DEFINE_bool(init_allocated_mem, false,
             "To find this error in time, we use init_allocated_mem to indicate "
             "that initializing the allocated memory with a small value "
             "during unit testing.");
+DECLARE_bool(benchmark);
 DECLARE_double(fraction_of_gpu_memory_to_use);
 
 namespace paddle {
@@ -59,11 +60,6 @@ size_t memory_usage(const platform::Place &p);
 
 using BuddyAllocator = detail::BuddyAllocator;
 
-std::unordered_map</*device id*/ int,
-                   std::pair</*current memory usage*/ uint64_t,
-                             /*peak memory usage*/ uint64_t>>
-    gpu_mem_info;
-
 BuddyAllocator *GetCPUBuddyAllocator() {
   // We tried thread_local for inference::RNN1 model, but that not works much
   // for multi-thread test.
@@ -144,6 +140,8 @@ BuddyAllocator *GetGPUBuddyAllocator(int gpu_id) {
     devices = platform::GetSelectedDevices();
     int gpu_num = devices.size();
 
+    allocation::GPUMemMonitor.Initialize(devices.size());
+
     a_arr = new BuddyAllocator *[gpu_num];
     for (size_t i = 0; i < devices.size(); ++i) {
       int dev_id = devices[i];
@@ -204,12 +202,7 @@ void *Alloc<platform::CUDAPlace>(const platform::CUDAPlace &place,
                  << string::HumanReadableSize(Used<platform::CUDAPlace>(place));
     platform::SetDeviceId(cur_dev);
   } else {
-    gpu_mem_info[place.device].first += size;
-    if (gpu_mem_info[place.device].first > gpu_mem_info[place.device].second) {
-      gpu_mem_info[place.device].second = gpu_mem_info[place.device].first;
-      VLOG(3) << "device: " << place.device << " peak memory usage : "
-              << (gpu_mem_info[place.device].second >> 20) << " MiB";
-    }
+    if (FLAGS_benchmark) allocation::GPUMemMonitor.Add(place.device, size);
     if (FLAGS_init_allocated_mem) {
       cudaMemset(ptr, 0xEF, size);
     }
@@ -225,7 +218,7 @@ void Free<platform::CUDAPlace>(const platform::CUDAPlace &place, void *p,
                                size_t size) {
 #ifdef PADDLE_WITH_CUDA
   GetGPUBuddyAllocator(place.device)->Free(p);
-  gpu_mem_info[place.device].first -= size;
+  if (FLAGS_benchmark) allocation::GPUMemMonitor.Minus(place.device, size);
 #else
   PADDLE_THROW("'CUDAPlace' is not supported in CPU only device.");
 #endif
@@ -335,6 +328,8 @@ size_t Usage::operator()(const platform::CUDAPinnedPlace &cuda_pinned) const {
 
 namespace allocation {
 
+LegacyMemMonitor GPUMemMonitor;
+
 Allocation *LegacyAllocator::AllocateImpl(size_t size, Allocator::Attr attr) {
   void *ptr = boost::apply_visitor(legacy::AllocVisitor(size), place_);
   return new Allocation(ptr, size, place_);
@@ -346,6 +341,63 @@ void LegacyAllocator::Free(Allocation *allocation) {
                  allocation->place());
   delete allocation;
 }
+
+bool MemInfo::Add(const size_t &size) {
+  std::lock_guard<std::mutex> lock(mutex_);
+  usage_ += size;
+  bool peak_point = usage_ > peak_usage_;
+  if (peak_point) peak_usage_ = usage_;
+  return peak_point;
+}
+
+void MemInfo::Minus(const size_t &size) {
+  std::lock_guard<std::mutex> lock(mutex_);
+  usage_ -= size;
+}
+
+uint64_t MemInfo::GetPeakUsage() { return peak_usage_; }
+
+LegacyMemMonitor::~LegacyMemMonitor() {
+  for (auto &item : gpu_mem_info_) delete item.second;
+}
+
+void LegacyMemMonitor::Initialize(const int &device_num) {
+  for (auto i = 0; i < device_num; ++i) {
+    gpu_mem_info_[i] = new MemInfo();
+  }
+}
+
+void LegacyMemMonitor::Add(const int &device, const size_t &size) {
+  if (gpu_mem_info_[device]->Add(size)) {
+    VLOG(3) << "#LegacyMemMonitor# device: " << device
+            << " peak memory usage : "
+            << (gpu_mem_info_[device]->GetPeakUsage() >> 20) << " MiB";
+  }
+}
+
+void LegacyMemMonitor::Minus(const int &device, const size_t &size) {
+  gpu_mem_info_[device]->Minus(size);
+}
+
+uint64_t LegacyMemMonitor::GetMemUsage(const int &device) {
+  return gpu_mem_info_.find(device) == gpu_mem_info_.end()
+             ? 0
+             : gpu_mem_info_[device]->GetPeakUsage();
+}
+
+void LegacyMemMonitor::PrintMemUsage() {
+  std::vector<int> devices;
+  for (const auto &item : gpu_mem_info_) {
+    devices.emplace_back(item.first);
+  }
+  std::sort(devices.begin(), devices.end());
+  for (const auto &device : devices) {
+    std::cout << "Device : " << device << " Peak Memory Usage : "
+              << (gpu_mem_info_[device]->GetPeakUsage() >> 20) << " MiB"
+              << std::endl;
+  }
+}
+
 }  // namespace allocation
 }  // namespace memory
 }  // namespace paddle
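
The core of the new monitor is MemInfo::Add, which updates the usage counter under a mutex and reports whether the allocation set a new peak, so callers can log only at peak points. A self-contained sketch of that pattern follows (a simplified stand-in, not the Paddle class itself):

// peak_counter_demo.cc -- standalone sketch of the lock-guarded peak
// tracking used by MemInfo above; a simplified stand-in, not Paddle code.
#include <cstdint>
#include <iostream>
#include <mutex>

class PeakCounter {
 public:
  // Returns true when this addition pushes usage to a new peak, letting
  // the caller log only on peak points.
  bool Add(uint64_t size) {
    std::lock_guard<std::mutex> lock(mutex_);
    usage_ += size;
    bool peak_point = usage_ > peak_usage_;
    if (peak_point) peak_usage_ = usage_;
    return peak_point;
  }
  void Minus(uint64_t size) {
    std::lock_guard<std::mutex> lock(mutex_);
    usage_ -= size;
  }
  uint64_t GetPeakUsage() const { return peak_usage_; }

 private:
  uint64_t usage_ = 0;
  uint64_t peak_usage_ = 0;
  std::mutex mutex_;
};

int main() {
  PeakCounter c;
  c.Add(100);             // usage 100, new peak (100)
  c.Minus(40);            // usage 60, peak stays 100
  bool peak = c.Add(80);  // usage 140, new peak (140)
  std::cout << "new peak? " << peak << ", peak = " << c.GetPeakUsage() << "\n";
  return 0;
}

As in the diff, the peak value is read without taking the lock; that mirrors MemInfo::GetPeakUsage and is tolerable for a logging-only counter.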

paddle/fluid/memory/allocation/legacy_allocator.h

Lines changed: 47 additions & 0 deletions
@@ -13,12 +13,59 @@
 // limitations under the License.
 
 #pragma once
+#include <algorithm>
+#include <mutex>  // NOLINT
+#include <unordered_map>
+#include <utility>
+#include <vector>
 #include "paddle/fluid/memory/allocation/allocator.h"
 #include "paddle/fluid/platform/place.h"
 namespace paddle {
 namespace memory {
 namespace allocation {
 
+class MemInfo {
+ public:
+  MemInfo() : usage_(0), peak_usage_(0) {}
+  MemInfo(const MemInfo &) = delete;
+  MemInfo &operator=(const MemInfo &) = delete;
+
+  // return a flag to indicate current operation will create a peak point or not
+  bool Add(const size_t &);
+  void Minus(const size_t &);
+
+  uint64_t GetPeakUsage();
+
+ private:
+  /* current memory usage*/
+  uint64_t usage_;
+  uint64_t peak_usage_;
+  std::mutex mutex_;
+};
+
+class LegacyMemMonitor {
+ public:
+  // used to store the GPU memory usage of each devices
+  using MemUsage = std::unordered_map</*device id*/ int,
+                                      /*mem usage info node*/ MemInfo *>;
+
+  MemUsage GetMemUsageInfo() { return gpu_mem_info_; }
+  ~LegacyMemMonitor();
+
+  void Initialize(const int &);
+  void Add(const int &, const size_t &);
+  void Minus(const int &, const size_t &);
+
+  uint64_t GetMemUsage(const int &);
+
+  void PrintMemUsage();
+
+ protected:
+  MemUsage gpu_mem_info_;
+};
+
+extern LegacyMemMonitor GPUMemMonitor;
+
 class LegacyAllocatorPrivate;
 class LegacyAllocator : public Allocator {
  public:

paddle/fluid/operators/batch_norm_op.cc

Lines changed: 4 additions & 2 deletions
@@ -589,8 +589,10 @@ class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
     op->SetInput("SavedVariance", Output("SavedVariance"));
 
     // used when setting use_global_stats True during training
-    op->SetInput("Mean", Output("MeanOut"));
-    op->SetInput("Variance", Output("VarianceOut"));
+    if (boost::get<bool>(GetAttr("use_global_stats"))) {
+      op->SetInput("Mean", Output("MeanOut"));
+      op->SetInput("Variance", Output("VarianceOut"));
+    }
 
     op->SetAttrMap(Attrs());
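
The new guard works because Paddle op attributes are stored in a boost::variant, from which boost::get extracts the typed value. A minimal sketch of that extraction pattern (the Attribute alias below is an abbreviated stand-in, not Paddle's full definition):

// attr_demo.cc -- sketch of pulling a typed value out of a variant
// attribute, as boost::get<bool>(GetAttr("use_global_stats")) does above.
#include <boost/variant.hpp>
#include <iostream>
#include <string>

// Abbreviated stand-in for Paddle's Attribute variant.
using Attribute = boost::variant<bool, int, float, std::string>;

int main() {
  Attribute use_global_stats = true;  // stand-in for GetAttr(...)
  if (boost::get<bool>(use_global_stats)) {
    std::cout << "wire Mean/Variance as grad-op inputs\n";
  }
  return 0;
}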

paddle/fluid/platform/place.cc

Lines changed: 6 additions & 0 deletions
@@ -14,6 +14,12 @@ limitations under the License. */
 
 #include "paddle/fluid/platform/place.h"
 
+DEFINE_bool(benchmark, false,
+            "Doing memory benchmark. It will make deleting scope synchronized, "
+            "and add some memory usage logs."
+            "Default cuda is asynchronous device, set to True will"
+            "force op run in synchronous mode.");
+
 namespace paddle {
 namespace platform {
 

paddle/fluid/pybind/pybind.cc

Lines changed: 8 additions & 0 deletions
@@ -37,6 +37,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/version.h"
 #include "paddle/fluid/imperative/layer.h"
 #include "paddle/fluid/memory/allocation/allocator_strategy.h"
+#include "paddle/fluid/memory/allocation/legacy_allocator.h"
 #include "paddle/fluid/operators/activation_op.h"
 #include "paddle/fluid/operators/py_func_op.h"
 #include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
@@ -127,6 +128,13 @@ PYBIND11_MODULE(core, m) {
   m.add_object("_cleanup",
                py::capsule([]() { ScopePool::Instance().Clear(); }));
 
+  m.def("get_mem_usage", [](int device) {
+    return memory::allocation::GPUMemMonitor.GetMemUsage(device);
+  });
+
+  m.def("print_mem_usage",
+        []() { return memory::allocation::GPUMemMonitor.PrintMemUsage(); });
+
   py::class_<imperative::VarBase>(m, "VarBase", R"DOC()DOC")
       // .def(py::init<>())
       .def(py::init<bool>(), py::arg("stop_gradient") = false)
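
For readers unfamiliar with pybind11: m.def binds a C++ lambda as a module-level Python function, which is how get_mem_usage and print_mem_usage reach the test below. A compilable toy module mirroring the two bindings (the module name and the peak_bytes counter are illustrative stand-ins for GPUMemMonitor):

// memdemo.cc -- toy pybind11 module mirroring the bindings above; the
// module name and the peak_bytes counter are illustrative stand-ins.
#include <cstdint>
#include <pybind11/pybind11.h>

namespace py = pybind11;

static uint64_t peak_bytes = 0;  // stand-in for per-device monitor state

PYBIND11_MODULE(memdemo, m) {
  // A lambda with parameters becomes a Python function taking them.
  m.def("get_mem_usage", [](int device) {
    (void)device;  // a real monitor would index per-device state
    return peak_bytes;
  });
  // A nullary lambda becomes a no-argument Python function.
  m.def("print_mem_usage", []() { py::print("peak:", peak_bytes); });
}

From Python this would then be used as: import memdemo; memdemo.get_mem_usage(0).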
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import os
+os.environ['FLAGS_benchmark'] = 'True'
+
+import numpy
+import paddle.fluid.core as core
+from paddle.fluid.executor import Executor
+from paddle.fluid.layers import mul, data
+
+
+class TestPeakMemoryMonitoring(unittest.TestCase):
+    def test_mul(self):
+
+        a = data(name='a', shape=[784], dtype='float32')
+        b = data(
+            name='b',
+            shape=[784, 100],
+            dtype='float32',
+            append_batch_size=False)
+        out = mul(x=a, y=b)
+
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+
+            a_np = numpy.random.random((100, 784)).astype('float32')
+            b_np = numpy.random.random((784, 100)).astype('float32')
+            self.assertEqual(0, core.get_mem_usage(0))
+            exe = Executor(place)
+            outs = exe.run(feed={'a': a_np, 'b': b_np}, fetch_list=[out])
+            out = outs[0]
+            #disable this assert since ctest will ignore the os.environ setting
+            #self.assertGreater(core.get_mem_usage(0), 0)
+
+            raised = False
+            try:
+                core.print_mem_usage()
+            except:
+                raised = True
+            self.assertFalse(raised, 'Exception raised')
+
+
+if __name__ == '__main__':
+    unittest.main()
