Skip to content

Commit 42470f1

Browse files
committed
test=develop
1 parent 0fca168 commit 42470f1

File tree

3 files changed

+50
-54
lines changed

3 files changed

+50
-54
lines changed

paddle/fluid/framework/selected_rows.cc

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -140,58 +140,6 @@ bool SelectedRows::HasKey(int64_t key) const {
140140
: true;
141141
}
142142

143-
int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown,
144-
bool is_test) {
145-
if (is_test) {
146-
auto iter = id_to_index_.find(key);
147-
if (iter == id_to_index_.end()) {
148-
return -1;
149-
} else {
150-
return iter->second;
151-
}
152-
}
153-
154-
rwlock_->RDLock();
155-
auto iter = id_to_index_.find(key);
156-
if (iter == id_to_index_.end()) {
157-
rwlock_->UNLock();
158-
if (!auto_grown) {
159-
PADDLE_THROW("key %d not found", key);
160-
}
161-
rwlock_->WRLock();
162-
auto map_size = id_to_index_.size();
163-
auto vector_size = rows_.size();
164-
if (map_size != vector_size) {
165-
rwlock_->UNLock();
166-
PADDLE_THROW(
167-
"id_to_index_ size %d should have the same size with rows_ %d",
168-
map_size, vector_size);
169-
}
170-
auto write_iter = id_to_index_.find(key);
171-
if (write_iter == id_to_index_.end()) {
172-
int row_num = rows_.size();
173-
if (row_num == value_->dims()[0]) {
174-
rwlock_->UNLock();
175-
PADDLE_THROW("selected rows is full, then length exceed %d", row_num);
176-
}
177-
// key logic to put a key into id_to_index_
178-
rows_.push_back(key);
179-
auto index = static_cast<int64_t>(rows_.size() - 1);
180-
id_to_index_[key] = index;
181-
rwlock_->UNLock();
182-
return index;
183-
} else {
184-
auto index = write_iter->second;
185-
rwlock_->UNLock();
186-
return index;
187-
}
188-
} else {
189-
auto index = iter->second;
190-
rwlock_->UNLock();
191-
return index;
192-
}
193-
}
194-
195143
void SelectedRows::SyncIndex() {
196144
rwlock_->WRLock();
197145
id_to_index_.clear();

paddle/fluid/framework/selected_rows.h

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,55 @@ class SelectedRows {
118118
*
119119
* @return index of the key.
120120
*/
121-
int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false);
121+
int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false) {
122+
if (is_test) {
123+
auto iter = id_to_index_.find(key);
124+
if (iter == id_to_index_.end()) {
125+
return -1;
126+
} else {
127+
return iter->second;
128+
}
129+
}
130+
rwlock_->RDLock();
131+
auto iter = id_to_index_.find(key);
132+
if (iter == id_to_index_.end()) {
133+
rwlock_->UNLock();
134+
if (!auto_grown) {
135+
PADDLE_THROW("key %d not found", key);
136+
}
137+
rwlock_->WRLock();
138+
auto map_size = id_to_index_.size();
139+
auto vector_size = rows_.size();
140+
if (map_size != vector_size) {
141+
rwlock_->UNLock();
142+
PADDLE_THROW(
143+
"id_to_index_ size %d should have the same size with rows_ %d",
144+
map_size, vector_size);
145+
}
146+
auto write_iter = id_to_index_.find(key);
147+
if (write_iter == id_to_index_.end()) {
148+
int row_num = rows_.size();
149+
if (row_num == value_->dims()[0]) {
150+
rwlock_->UNLock();
151+
PADDLE_THROW("selected rows is full, then length exceed %d", row_num);
152+
}
153+
// key logic to put a key into id_to_index_
154+
rows_.push_back(key);
155+
auto index = static_cast<int64_t>(rows_.size() - 1);
156+
id_to_index_[key] = index;
157+
rwlock_->UNLock();
158+
return index;
159+
} else {
160+
auto index = write_iter->second;
161+
rwlock_->UNLock();
162+
return index;
163+
}
164+
} else {
165+
auto index = iter->second;
166+
rwlock_->UNLock();
167+
return index;
168+
}
169+
}
122170

123171
void SyncIndex();
124172
/*

paddle/fluid/operators/math/matrix_bit_code.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::LoDTensor& tmat,
142142

143143
for (size_t k = 0; k < input_width; ++k) {
144144
int64_t row_index =
145-
weight->AutoGrownIndex(static_cast<int64_t>(index), false);
145+
weight->AutoGrownIndex(static_cast<int64_t>(index), false, true);
146146

147147
weight_value[row_index * weight_width + k] +=
148148
tmat_value[i * tmat_width + j] * input_value[input_width * i + k];

0 commit comments

Comments
 (0)