Skip to content

Commit 654f5d3

Browse files
authored
Merge pull request #11012 from jacquesqiao/add-auto_grown_mutex
add auto_grown_mutex for selected rows
2 parents c95cd47 + fa2079b commit 654f5d3

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

paddle/fluid/framework/selected_rows.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ namespace paddle {
1818
namespace framework {
1919

2020
struct ReAllocateVisitor {
21-
ReAllocateVisitor(framework::Tensor* tensor, const framework::DDim& dims)
22-
: tensor_(tensor), dims_(dims) {}
21+
ReAllocateVisitor(const framework::DDim& dims, framework::Tensor* tensor)
22+
: dims_(dims), tensor_(tensor) {}
2323

2424
template <typename T>
2525
void operator()() const {
@@ -34,8 +34,8 @@ struct ReAllocateVisitor {
3434
tensor_->ShareDataWith(cpu_tensor);
3535
}
3636

37-
framework::Tensor* tensor_;
3837
framework::DDim dims_;
38+
framework::Tensor* tensor_;
3939
};
4040

4141
struct TensorCopyVisitor {
@@ -158,6 +158,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
158158
}
159159
PADDLE_ENFORCE_EQ(value.dims()[0], static_cast<size_t>(1),
160160
"The first dim of value should be 1.");
161+
std::lock_guard<std::mutex> lock(*auto_grown_mutex_.get());
161162
auto index = Index(key);
162163
bool is_new_key = false;
163164
if (index == -1) {
@@ -169,7 +170,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
169170
auto dims = value_->dims();
170171
dims[0] = (dims[0] + 1) << 1;
171172
framework::VisitDataType(framework::ToDataType(value.type()),
172-
ReAllocateVisitor(value_.get(), dims));
173+
ReAllocateVisitor(dims, value_.get()));
173174
}
174175
}
175176

paddle/fluid/framework/selected_rows.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ limitations under the License. */
1515
#pragma once
1616

1717
#include <algorithm>
18+
#include <memory>
19+
#include <mutex> // NOLINT
1820
#include <utility>
1921
#include <vector>
2022

@@ -46,11 +48,13 @@ class SelectedRows {
4648
SelectedRows(const std::vector<int64_t>& rows, const int64_t& height)
4749
: rows_(rows), height_(height) {
4850
value_.reset(new Tensor());
51+
auto_grown_mutex_.reset(new std::mutex);
4952
}
5053

5154
SelectedRows() {
5255
height_ = 0;
5356
value_.reset(new Tensor());
57+
auto_grown_mutex_.reset(new std::mutex);
5458
}
5559

5660
platform::Place place() const { return value_->place(); }
@@ -125,6 +129,7 @@ class SelectedRows {
125129
Vector<int64_t> rows_;
126130
std::unique_ptr<Tensor> value_{nullptr};
127131
int64_t height_;
132+
std::unique_ptr<std::mutex> auto_grown_mutex_{nullptr};
128133
};
129134

130135
/*

0 commit comments

Comments
 (0)