@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include <paddle/fluid/operators/math/concat.h>
 #include <numeric>
 
 #include "paddle/fluid/framework/lod_rank_table.h"
@@ -24,6 +25,50 @@ namespace operators {
 
 using LoD = framework::LoD;
 
+class ArrayToLoDFunctor;
+template <typename DeviceContext>
+struct ArrayToLoDFunctorImpl {
+  const ArrayToLoDFunctor *prev_functor_;
+  DeviceContext *dev_ctx_;
+
+  template <typename T>
+  void apply();
+};
+
+struct ArrayToLoDFunctor : public boost::static_visitor<void> {
+  std::vector<framework::Tensor> in;
+  mutable framework::Tensor *out;
+
+  template <typename Place>
+  void operator()(Place place) const {
+    auto &pool = platform::DeviceContextPool::Instance();
+    if (std::is_same<Place, platform::CPUPlace>::value) {
+      Apply(static_cast<platform::CPUDeviceContext *>(pool.Get(place)));
+    } else {
+#ifdef PADDLE_WITH_CUDA
+      Apply(static_cast<platform::CUDADeviceContext *>(pool.Get(place)));
+#else
+      PADDLE_THROW("Fluid is not compiled with CUDA");
+#endif
+    }
+  }
+
+  template <typename DeviceContext>
+  void Apply(DeviceContext *dev_ctx) const {
+    ArrayToLoDFunctorImpl<DeviceContext> functor;
+    functor.dev_ctx_ = dev_ctx;
+    functor.prev_functor_ = this;
+    framework::VisitDataType(framework::ToDataType(out->type()), functor);
+  }
+};
+
+template <typename DeviceContext>
+template <typename T>
+void ArrayToLoDFunctorImpl<DeviceContext>::apply() {
+  math::ConcatFunctor<DeviceContext, T> func;
+  func(*dev_ctx_, prev_functor_->in, 0, prev_functor_->out);
+}
+
 class ArrayToLoDTensorOp : public framework::OperatorBase {
  public:
   ArrayToLoDTensorOp(const std::string &type,
@@ -47,14 +92,18 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
     int rank = x[0].dims().size();
     platform::Place place = x[0].place();
     std::type_index data_type = x[0].type();
-    framework::DDim ins_dims = framework::slice_ddim(x[0].dims(), 1, rank);
     int64_t batch_size = x[0].dims()[0];
+    framework::DDim ins_dims = rank > 1
+                                   ? framework::slice_ddim(x[0].dims(), 1, rank)
+                                   : framework::make_ddim({0});
     for (size_t i = 1; i < x.size(); ++i) {
-      PADDLE_ENFORCE_EQ(framework::slice_ddim(x[i].dims(), 1, rank), ins_dims,
+      auto ins_i_dims = rank > 1 ? framework::slice_ddim(x[i].dims(), 1, rank)
+                                 : framework::make_ddim({0});
+      PADDLE_ENFORCE_EQ(ins_i_dims, ins_dims,
                         "The dimension of the %zu'th element in LoDTensorArray "
                         "differs from previous ones.",
                         i);
-      PADDLE_ENFORCE(platform::places_are_same_class(x[i].place(), place),
+      PADDLE_ENFORCE(x[i].place() == place,
                      "The place class of the %zu'th element in LoDTensorArray "
                      "differs from previous ones.",
                      i);
@@ -82,13 +131,14 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
     // Build LoDTensor `out`
     framework::LoD *out_lod = out->mutable_lod();
     out_lod->clear();
-    size_t out_offset = 0;
     auto prefix_lod = rank_table.coarse_lod();
     prefix_lod.emplace_back();
     auto &cur_level_lod = prefix_lod.back();
     cur_level_lod.push_back(0);
+    ArrayToLoDFunctor functor;
     for (size_t idx : table_item_idx) {
       cur_level_lod.push_back(cur_level_lod.back() + table_items[idx].length);
+      PADDLE_ENFORCE_LE(table_items[idx].length, x.size());
       for (size_t x_idx = 0; x_idx < table_items[idx].length; ++x_idx) {
         auto lod_and_offset = framework::GetSubLoDAndAbsoluteOffset(
             x[x_idx].lod(), idx, idx + 1, 0);
@@ -106,17 +156,11 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
         if (len == 0) {
           continue;
         }
-        auto slice = out->Slice(out_offset, out_offset + len);
-
-        platform::DeviceContextPool &pool =
-            platform::DeviceContextPool::Instance();
-        auto &dev_ctx = *pool.Get(place);
-
-        framework::TensorCopy(x[x_idx].Slice(start_offset, end_offset), place,
-                              dev_ctx, &slice);
-        out_offset += len;
+        functor.in.emplace_back(x[x_idx].Slice(start_offset, end_offset));
       }
     }
+    functor.out = out;
+    platform::VisitPlace(place, functor);
     out_lod->insert(out_lod->begin(), prefix_lod.begin(), prefix_lod.end());
   }
 };
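
Note on the change: instead of issuing one `framework::TensorCopy` per slice, the patch collects all input slices into `functor.in` and concatenates them with a single `math::ConcatFunctor` call, dispatching in two stages: `platform::VisitPlace` resolves the placement at runtime (via the `boost::static_visitor`-based `ArrayToLoDFunctor`), and `framework::VisitDataType` then resolves the element type (via `ArrayToLoDFunctorImpl`). Below is a minimal, self-contained sketch of that two-level dispatch pattern, not Paddle code: every name in it (`CPUPlace`, `GPUPlace`, `Concat`, `ConcatVisitor`) is an illustrative stand-in, and C++17 `std::variant`/`std::visit` replaces `boost::variant`/`boost::static_visitor`.

```cpp
// Sketch of two-level runtime dispatch: place first, then element type,
// so the final kernel is specialized on both. Compile with -std=c++17.
#include <cstdint>
#include <cstdio>
#include <variant>
#include <vector>

struct CPUPlace {};
struct GPUPlace {};
using Place = std::variant<CPUPlace, GPUPlace>;

enum class DType { kFP32, kINT64 };

// Stage 2: the kernel specialized on both place and element type.
// A real implementation would issue one batched device copy here.
template <typename PlaceT, typename T>
void Concat(const std::vector<std::vector<T>> &in, std::vector<T> *out) {
  for (const auto &piece : in) {
    out->insert(out->end(), piece.begin(), piece.end());
  }
}

// Stage 1b: dispatch on the runtime dtype once PlaceT is fixed statically.
template <typename PlaceT>
void ApplyConcat(DType dtype) {
  switch (dtype) {
    case DType::kFP32: {
      std::vector<std::vector<float>> in = {{1.f, 2.f}, {3.f}};
      std::vector<float> out;
      Concat<PlaceT>(in, &out);
      std::printf("fp32 concat -> %zu elements\n", out.size());
      break;
    }
    case DType::kINT64: {
      std::vector<std::vector<int64_t>> in = {{1, 2}, {3, 4}};
      std::vector<int64_t> out;
      Concat<PlaceT>(in, &out);
      std::printf("int64 concat -> %zu elements\n", out.size());
      break;
    }
  }
}

// Stage 1a: dispatch on the runtime place; each overload fixes PlaceT.
struct ConcatVisitor {
  DType dtype;
  void operator()(CPUPlace) const { ApplyConcat<CPUPlace>(dtype); }
  void operator()(GPUPlace) const { ApplyConcat<GPUPlace>(dtype); }
};

int main() {
  Place place = CPUPlace{};
  std::visit(ConcatVisitor{DType::kFP32}, place);  // fp32 concat -> 3 elements
}
```

Batching into one concat means a single device-context lookup and, on GPU, one kernel launch for the whole output rather than one per slice; the visitor layers exist so the type- and device-specialized kernel can still be selected from purely runtime information (place and dtype).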