Skip to content

Commit e3c041d

Browse files
committed
add auto_grown_mutex for selected rows
1 parent 15db5a5 commit e3c041d

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

paddle/fluid/framework/selected_rows.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ namespace paddle {
1818
namespace framework {
1919

2020
struct ReAllocateVisitor {
21-
ReAllocateVisitor(framework::Tensor* tensor, const framework::DDim& dims)
22-
: tensor_(tensor), dims_(dims) {}
21+
ReAllocateVisitor(const framework::DDim& dims, framework::Tensor* tensor)
22+
: dims_(dims), tensor_(tensor) {}
2323

2424
template <typename T>
2525
void operator()() const {
@@ -153,6 +153,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
153153
}
154154
PADDLE_ENFORCE_EQ(value.dims()[0], static_cast<size_t>(1),
155155
"The first dim of value should be 1.");
156+
std::lock_guard<std::mutex> lock(auto_grown_mutex_);
156157
auto index = Index(key);
157158
bool is_new_key = false;
158159
if (index == -1) {
@@ -164,7 +165,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
164165
auto dims = value_->dims();
165166
dims[0] = (dims[0] + 1) << 1;
166167
framework::VisitDataType(framework::ToDataType(value.type()),
167-
ReAllocateVisitor(value_.get(), dims));
168+
ReAllocateVisitor(dims, value_.get()));
168169
}
169170
}
170171

paddle/fluid/framework/selected_rows.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ class SelectedRows {
125125
Vector<int64_t> rows_;
126126
std::unique_ptr<Tensor> value_{nullptr};
127127
int64_t height_;
128+
std::mutex auto_grown_mutex_;
128129
};
129130

130131
/*

0 commit comments

Comments
 (0)