@@ -18,8 +18,8 @@ namespace paddle {
18
18
namespace framework {
19
19
20
20
struct ReAllocateVisitor {
21
- ReAllocateVisitor (framework::Tensor* tensor, const framework::DDim& dims )
22
- : tensor_(tensor ), dims_(dims ) {}
21
+ ReAllocateVisitor (const framework::DDim& dims, framework::Tensor* tensor )
22
+ : dims_(dims ), tensor_(tensor ) {}
23
23
24
24
template <typename T>
25
25
void operator ()() const {
@@ -34,8 +34,8 @@ struct ReAllocateVisitor {
34
34
tensor_->ShareDataWith (cpu_tensor);
35
35
}
36
36
37
- framework::Tensor* tensor_;
38
37
framework::DDim dims_;
38
+ framework::Tensor* tensor_;
39
39
};
40
40
41
41
struct TensorCopyVisitor {
@@ -158,6 +158,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
158
158
}
159
159
PADDLE_ENFORCE_EQ (value.dims ()[0 ], static_cast <size_t >(1 ),
160
160
" The first dim of value should be 1." );
161
+ std::lock_guard<std::mutex> lock (*auto_grown_mutex_.get ());
161
162
auto index = Index (key);
162
163
bool is_new_key = false ;
163
164
if (index == -1 ) {
@@ -169,7 +170,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) {
169
170
auto dims = value_->dims ();
170
171
dims[0 ] = (dims[0 ] + 1 ) << 1 ;
171
172
framework::VisitDataType (framework::ToDataType (value.type ()),
172
- ReAllocateVisitor (value_.get (), dims ));
173
+ ReAllocateVisitor (dims, value_.get ()));
173
174
}
174
175
}
175
176
0 commit comments