
Commit f1bf133

Merge pull request #12823 from jacquesqiao/cherry-pick-rw-lock
Cherry pick rw lock
2 parents ef9029d + 3f103a7 commit f1bf133

13 files changed: +383 additions, -185 deletions


paddle/fluid/framework/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -115,6 +115,8 @@ cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
 # cc_test(channel_test SRCS channel_test.cc)
 cc_test(tuple_test SRCS tuple_test.cc )
 
+cc_test(rw_lock_test SRCS rw_lock_test.cc)
+
 # disable test temporarily.
 # TODO https://github.com/PaddlePaddle/Paddle/issues/11971
 # cc_test(concurrency_test SRCS concurrency_test.cc DEPS go_op channel_close_op channel_create_op

paddle/fluid/framework/rw_lock.h

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <pthread.h>

#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {

struct RWLock {
  RWLock() { pthread_rwlock_init(&lock_, nullptr); }

  ~RWLock() { pthread_rwlock_destroy(&lock_); }

  void RDLock() {
    PADDLE_ENFORCE_EQ(pthread_rwlock_rdlock(&lock_), 0,
                      "acquire read lock failed");
  }

  void WRLock() {
    PADDLE_ENFORCE_EQ(pthread_rwlock_wrlock(&lock_), 0,
                      "acquire write lock failed");
  }

  void UNLock() {
    PADDLE_ENFORCE_EQ(pthread_rwlock_unlock(&lock_), 0, "unlock failed");
  }

 private:
  pthread_rwlock_t lock_;
};

}  // namespace framework
}  // namespace paddle
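For orientation only, here is a minimal sketch (not part of this commit) of how the RWLock wrapper above would typically guard a structure with many concurrent readers and occasional writers. SharedTable, Find, and Insert are illustrative names, not Paddle APIs; the lock calls themselves follow the header shown above.

    #include <cstdint>
    #include <string>
    #include <unordered_map>

    #include "paddle/fluid/framework/rw_lock.h"

    // Hypothetical shared table: many reader threads, a few writer threads.
    class SharedTable {
     public:
      bool Find(int64_t key, std::string* out) {
        lock_.RDLock();  // multiple readers may hold the lock at once
        auto it = data_.find(key);
        bool found = it != data_.end();
        if (found) *out = it->second;
        lock_.UNLock();
        return found;
      }

      void Insert(int64_t key, const std::string& value) {
        lock_.WRLock();  // writers get exclusive access
        data_[key] = value;
        lock_.UNLock();
      }

     private:
      paddle::framework::RWLock lock_;
      std::unordered_map<int64_t, std::string> data_;
    };

Note that RWLock reports a failed pthread call through PADDLE_ENFORCE_EQ, so callers do not check return codes themselves.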
paddle/fluid/framework/rw_lock_test.cc

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/rw_lock.h"
#include <gtest/gtest.h>
#include <chrono>  // NOLINT
#include <thread>  // NOLINT
#include <vector>

namespace f = paddle::framework;

void f1(f::RWLock *lock) {
  lock->RDLock();
  lock->UNLock();
}

TEST(RWLOCK, read_read) {
  f::RWLock lock;
  lock.RDLock();
  std::thread t1(f1, &lock);
  std::thread t2(f1, &lock);
  t1.join();
  t2.join();
  lock.UNLock();
}

void f2(f::RWLock *lock, std::vector<int> *result) {
  lock->RDLock();
  ASSERT_EQ(result->size(), 0UL);
  lock->UNLock();
}

void f3(f::RWLock *lock, std::vector<int> *result) {
  lock->WRLock();
  result->push_back(1);
  lock->UNLock();
}

TEST(RWLOCK, read_write) {
  f::RWLock lock;
  std::vector<int> result;

  lock.RDLock();
  std::thread t1(f2, &lock, &result);
  t1.join();
  std::thread t2(f3, &lock, &result);
  std::this_thread::sleep_for(std::chrono::seconds(1));
  ASSERT_EQ(result.size(), 0UL);
  lock.UNLock();
  t2.join();
  ASSERT_EQ(result.size(), 1UL);
}

void f4(f::RWLock *lock, std::vector<int> *result) {
  lock->RDLock();
  ASSERT_EQ(result->size(), 1UL);
  lock->UNLock();
}

TEST(RWLOCK, write_read) {
  f::RWLock lock;
  std::vector<int> result;

  lock.WRLock();
  std::thread t1(f4, &lock, &result);
  std::this_thread::sleep_for(std::chrono::seconds(1));
  result.push_back(1);
  lock.UNLock();
  t1.join();
}
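The tests above pair every RDLock()/WRLock() with a manual UNLock() on each path. As a side note (not part of this commit), a small RAII guard in the style of std::lock_guard could make that pairing automatic; ScopedRDLock and ReaderTask below are purely hypothetical helpers sketched for illustration.

    #include <cstddef>
    #include <vector>

    #include "paddle/fluid/framework/rw_lock.h"

    namespace f = paddle::framework;

    // Hypothetical scoped guard: acquires the read lock in the constructor and
    // releases it in the destructor, so UNLock() cannot be forgotten.
    class ScopedRDLock {
     public:
      explicit ScopedRDLock(f::RWLock* lock) : lock_(lock) { lock_->RDLock(); }
      ~ScopedRDLock() { lock_->UNLock(); }

     private:
      f::RWLock* lock_;
    };

    void ReaderTask(f::RWLock* lock, const std::vector<int>* result,
                    size_t* out) {
      ScopedRDLock guard(lock);  // released automatically at end of scope
      *out = result->size();
    }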

paddle/fluid/framework/selected_rows.cc

Lines changed: 60 additions & 50 deletions
@@ -120,66 +120,76 @@ bool SelectedRows::HasKey(int64_t key) const {
              : true;
 }
 
-std::vector<std::pair<int64_t, int64_t>> SelectedRows::Get(
-    const std::vector<int64_t>& keys, framework::Tensor* value) const {
+int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown) {
+  rwlock_->RDLock();
+  auto iter = id_to_index_.find(key);
+  if (iter == id_to_index_.end()) {
+    rwlock_->UNLock();
+    if (!auto_grown) {
+      PADDLE_THROW("key %d not found", key);
+    }
+    rwlock_->WRLock();
+    auto map_size = id_to_index_.size();
+    auto vector_size = rows_.size();
+    if (map_size != vector_size) {
+      rwlock_->UNLock();
+      PADDLE_THROW(
+          "id_to_index_ size %d should have the same size with rows_ %d",
+          map_size, vector_size);
+    }
+    auto write_iter = id_to_index_.find(key);
+    if (write_iter == id_to_index_.end()) {
+      size_t row_num = rows_.size();
+      if (row_num == value_->dims()[0]) {
+        rwlock_->UNLock();
+        PADDLE_THROW("selected rows is full, then length exceed %d", row_num);
+      }
+      // key logic to put a key into id_to_index_
+      rows_.push_back(key);
+      auto index = static_cast<int64_t>(rows_.size() - 1);
+      id_to_index_[key] = index;
+      rwlock_->UNLock();
+      return index;
+    } else {
+      auto index = write_iter->second;
+      rwlock_->UNLock();
+      return index;
+    }
+  } else {
+    auto index = iter->second;
+    rwlock_->UNLock();
+    return index;
+  }
+}
+
+void SelectedRows::SyncIndex() {
+  rwlock_->WRLock();
+  id_to_index_.clear();
+  for (size_t i = 0; i < rows_.size(); ++i) {
+    id_to_index_[rows_[i]] = i;
+  }
+  rwlock_->UNLock();
+}
+
+void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
+                       bool auto_grown) {
   PADDLE_ENFORCE(value->IsInitialized(),
                  "The value tensor should be initialized.");
-  std::vector<std::pair<int64_t, int64_t>> non_keys_pair;
-  if (keys.empty()) {
+  if (ids.numel() == 0) {
     VLOG(3) << "keys is empty, please check data!";
   } else {
     int64_t value_width = value_->numel() / value_->dims()[0];
     PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0],
                       "output tensor should have the same shape with table "
                       "except the dims[0].");
-
-    for (size_t i = 0; i < keys.size(); ++i) {
-      int64_t index = Index(keys[i]);
-      if (index == -1) {
-        non_keys_pair.push_back(
-            std::make_pair(keys[i], static_cast<int64_t>(i)));
-      } else {
-        framework::VisitDataType(
-            framework::ToDataType(value_->type()),
-            TensorCopyVisitor(value, i * value_width, *value_.get(),
-                              index * value_width, value_width));
-      }
+    for (size_t i = 0; i < ids.numel(); ++i) {
+      int64_t index = AutoGrownIndex(ids.data<int64_t>()[i], auto_grown);
+      framework::VisitDataType(
+          framework::ToDataType(value_->type()),
+          TensorCopyVisitor(value, i * value_width, *value_.get(),
+                            index * value_width, value_width));
     }
   }
-  return non_keys_pair;
-}
-
-bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
-  PADDLE_ENFORCE(value.IsInitialized(), "The value should be initialized.");
-  if (value_->IsInitialized()) {
-    PADDLE_ENFORCE_EQ(
-        value.type(), value_->type(),
-        "The type of the value should be same with the original value");
-  }
-  PADDLE_ENFORCE_EQ(value.dims()[0], static_cast<size_t>(1),
-                    "The first dim of value should be 1.");
-  std::lock_guard<std::mutex> lock(*auto_grown_mutex_.get());
-  auto index = Index(key);
-  bool is_new_key = false;
-  if (index == -1) {
-    rows_.push_back(key);
-    index = rows_.size() - 1;
-    is_new_key = true;
-    // whether need to resize the table
-    if (static_cast<int64_t>(rows_.size()) > value_->dims()[0]) {
-      auto dims = value_->dims();
-      dims[0] = (dims[0] + 1) << 1;
-      framework::VisitDataType(framework::ToDataType(value.type()),
-                               ReAllocateVisitor(dims, value_.get()));
-    }
-  }
-
-  framework::VisitDataType(
-      framework::ToDataType(value.type()),
-      TensorCopyVisitor(value_.get(),
-                        index * value_->numel() / value_->dims()[0], value,
-                        static_cast<int64_t>(0), value.numel()));
-  return is_new_key;
 }
 
 }  // namespace framework
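The new AutoGrownIndex first probes id_to_index_ under the read lock; on a miss it releases the read lock, takes the write lock, and looks the key up again before inserting, because another writer may have added the key in the window between the two lock acquisitions. A standalone sketch of that check-release-recheck pattern is shown below, using std::shared_timed_mutex instead of the pthread-based RWLock and omitting the auto_grown flag and the capacity check; AutoGrownMap and IndexOf are illustrative names only, not part of the commit.

    #include <cstdint>
    #include <mutex>          // std::unique_lock
    #include <shared_mutex>   // std::shared_timed_mutex, std::shared_lock (C++14)
    #include <unordered_map>
    #include <vector>

    // Hypothetical grow-on-miss index, mirroring the locking discipline of
    // SelectedRows::AutoGrownIndex.
    class AutoGrownMap {
     public:
      int64_t IndexOf(int64_t key) {
        {
          std::shared_lock<std::shared_timed_mutex> rd(mu_);  // shared (read) lock
          auto it = id_to_index_.find(key);
          if (it != id_to_index_.end()) return it->second;
        }  // release the read lock before taking the write lock
        std::unique_lock<std::shared_timed_mutex> wr(mu_);  // exclusive (write) lock
        auto it = id_to_index_.find(key);  // re-check: another writer may have won
        if (it != id_to_index_.end()) return it->second;
        rows_.push_back(key);
        int64_t index = static_cast<int64_t>(rows_.size() - 1);
        id_to_index_[key] = index;
        return index;
      }

     private:
      std::shared_timed_mutex mu_;
      std::unordered_map<int64_t, int64_t> id_to_index_;
      std::vector<int64_t> rows_;
    };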

paddle/fluid/framework/selected_rows.h

Lines changed: 35 additions & 28 deletions
@@ -17,10 +17,12 @@ limitations under the License. */
 #include <algorithm>
 #include <memory>
 #include <mutex>  // NOLINT
+#include <unordered_map>
 #include <utility>
 #include <vector>
 
 #include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/rw_lock.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/memory/memcpy.h"
 
@@ -48,13 +50,13 @@ class SelectedRows {
   SelectedRows(const std::vector<int64_t>& rows, const int64_t& height)
       : rows_(rows), height_(height) {
     value_.reset(new Tensor());
-    auto_grown_mutex_.reset(new std::mutex);
+    rwlock_.reset(new RWLock);
   }
 
   SelectedRows() {
    height_ = 0;
     value_.reset(new Tensor());
-    auto_grown_mutex_.reset(new std::mutex);
+    rwlock_.reset(new RWLock);
   }
 
   platform::Place place() const { return value_->place(); }
@@ -74,47 +76,51 @@ class SelectedRows {
   void set_rows(const Vector<int64_t>& rows) { rows_ = rows; }
 
   /*
-   * @brief wheter has the specified key in the table.
+   * @brief Get the index of key in rows
+   *
+   * @return -1 if the key does not exists.
+   */
+  int64_t Index(int64_t key) const {
+    auto it = std::find(rows_.begin(), rows_.end(), key);
+    if (it == rows_.end()) {
+      PADDLE_THROW("id %s not in table", key);
+    }
+    return static_cast<int64_t>(std::distance(rows_.begin(), it));
+  }
+
+  /*
+   * @brief whether has the specified key in the table.
    *
    * @return true if the key is exists.
    */
   bool HasKey(int64_t key) const;
 
   /*
-   * @brief Get value by the key list, if the
+   * @brief Get value by the key list.
+   * Note!!! this interface is only used when selected_rows is used as
+   * parameters
+   * for distribute lookup table.
    *
    * @return a list of pair which contains the non-exists key and the index in
   * the value
    */
-  std::vector<std::pair<int64_t, int64_t>> Get(const std::vector<int64_t>& keys,
-                                               framework::Tensor* value) const;
+  void Get(const framework::Tensor& ids, framework::Tensor* value,
+           bool auto_grown = false);
 
   /*
-   * @brief Set a key-value pair into the table.
-   * This function will double the value memory if it's not engouth.
+   * @brief Get the index of the key from id_to_index_ map. If the key not
+   * exist,
+   * add the key into id_to_index_.
    *
-   * @note:
-   * 1. The first dim of the value should be 1
-   * 2. The value should be initialized and the data type
-   * should be the same with the table.
-   *
-   * @return true if the key is a new one, otherwise false
+   * Note!!! this interface is only used when selected_rows is used as
+   * parameters
+   * for distribute lookup table.
    *
+   * @return index of the key.
    */
-  bool Set(int64_t key, const Tensor& value);
+  int64_t AutoGrownIndex(int64_t key, bool auto_grown);
 
-  /*
-   * @brief Get the index of key in rows
-   *
-   * @return -1 if the key does not exists.
-   */
-  int64_t Index(int64_t key) const {
-    auto it = std::find(rows_.begin(), rows_.end(), key);
-    if (it == rows_.end()) {
-      return static_cast<int64_t>(-1);
-    }
-    return static_cast<int64_t>(std::distance(rows_.begin(), it));
-  }
+  void SyncIndex();
 
   DDim GetCompleteDims() const {
     std::vector<int64_t> dims = vectorize(value_->dims());
@@ -127,9 +133,10 @@ class SelectedRows {
   // SelectedRows are simply concated when adding together. Until a
   // SelectedRows add a Tensor, will the duplicate rows be handled.
   Vector<int64_t> rows_;
+  std::unordered_map<int64_t, int64_t> id_to_index_;
   std::unique_ptr<Tensor> value_{nullptr};
   int64_t height_;
-  std::unique_ptr<std::mutex> auto_grown_mutex_{nullptr};
+  std::unique_ptr<RWLock> rwlock_{nullptr};
 };
 
 /*
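A rough usage sketch of the revised SelectedRows interface follows. It is not taken from this commit; it assumes the usual fluid helpers such as framework::make_ddim, Tensor::mutable_data, mutable_value(), and mutable_rows(), which exist elsewhere in the framework headers but are not shown in this diff.

    #include "paddle/fluid/framework/selected_rows.h"
    #include "paddle/fluid/platform/place.h"

    namespace f = paddle::framework;

    // Hypothetical demo routine showing the SyncIndex/Get(auto_grown) flow.
    void SparseTableDemo() {
      paddle::platform::CPUPlace cpu;
      f::SelectedRows table;

      // Reserve capacity for 4 rows of width 10 in the underlying value tensor.
      table.mutable_value()->mutable_data<float>(f::make_ddim({4, 10}), cpu);

      // Seed a row directly, then rebuild id_to_index_ so it matches rows_.
      table.mutable_rows()->push_back(3);
      table.SyncIndex();

      // Ids to look up: id 3 already exists; id 7 is appended to rows_ and
      // id_to_index_ because auto_grown is true.
      f::Tensor ids;
      int64_t* ids_data = ids.mutable_data<int64_t>(f::make_ddim({2}), cpu);
      ids_data[0] = 3;
      ids_data[1] = 7;

      // The output tensor must be initialized with the same row width as the table.
      f::Tensor out;
      out.mutable_data<float>(f::make_ddim({2, 10}), cpu);

      table.Get(ids, &out, /*auto_grown=*/true);
    }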
