@@ -286,62 +286,86 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
   DeserializeFromStream(is, static_cast<Tensor *>(tensor), dev_ctx);
 }
 
-// TODO(tonyyang-svail): make this function support LoD
 std::vector<LoDTensor> LoDTensor::SplitLoDTensor(
     const std::vector<platform::Place> places) const {
   check_memory_size();
-  PADDLE_ENFORCE(lod().empty(), "Disable parallel lod for now");
-  size_t result_size = std::min(static_cast<size_t>(dims()[0]), places.size());
-  size_t remainder = dims()[0] % places.size();
+  int batch_size =
+      lod().empty() ? dims()[0] : static_cast<int>(lod()[0].size()) - 1;
+  size_t result_size = std::min(static_cast<size_t>(batch_size), places.size());
+  size_t remainder = batch_size % places.size();
 
   std::vector<LoDTensor> results;
   results.reserve(result_size);
 
-  int step_width = static_cast<int>(dims()[0] / result_size);
+  int step_width = static_cast<int>(batch_size / result_size);
   for (size_t i = 0; i < result_size; ++i) {
     int begin = static_cast<int>(i * step_width);
     int end = static_cast<int>((i + 1) * step_width);
     if (i + 1 == places.size()) {  // last
       end += remainder;
     }
 
-    auto src = Slice(begin, end);
-    auto &dst_place = places[i];
     LoDTensor dst;
-    if (!(dst_place == place())) {
+    if (lod().empty()) {
+      auto src = Slice(begin, end);
+      auto &dst_place = places[i];
       framework::Copy(src, dst_place, &dst);
-    } else {  // It is no need to copy if src_place and dst_place are same.
-      dst.ShareDataWith(src);
+    } else {
+      auto lod_and_offset = GetSubLoDAndAbsoluteOffset(lod(), begin, end, 0);
+
+      auto &offset = lod_and_offset.second;
+      auto src = Slice(offset.first, offset.second);
+      auto &dst_place = places[i];
+      framework::Copy(src, dst_place, &dst);
+
+      LoD my_lod;
+      for (auto &l : lod_and_offset.first) {
+        std::vector<size_t> v{0};
+        for (auto &ll : l) {
+          v.push_back(ll + v.back());
+        }
+        my_lod.emplace_back(v);
+      }
+      dst.set_lod(my_lod);
     }
     results.emplace_back(dst);
   }
 
   return results;
 }
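Note on the new `else` branch above: `GetSubLoDAndAbsoluteOffset` hands back the selected sub-LoD in relative (per-sequence length) form together with the absolute row range at the lowest level, so the `my_lod` loop has to turn those lengths back into the cumulative-offset form a `LoDTensor` stores. A minimal standalone sketch of that running-sum conversion, with plain `std::vector`s standing in for `framework::LoD`; the helper name and the sample lengths are made up for illustration:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in for framework::LoD: one vector<size_t> per level.
using LoD = std::vector<std::vector<size_t>>;

// Convert a LoD given as per-sequence lengths into cumulative offset
// form, e.g. lengths {2, 3, 1} -> offsets {0, 2, 5, 6}.
LoD LengthsToOffsets(const LoD &lengths) {
  LoD offsets;
  for (const auto &level : lengths) {
    std::vector<size_t> v{0};       // every level starts at offset 0
    for (size_t len : level) {
      v.push_back(len + v.back());  // running sum, as in SplitLoDTensor
    }
    offsets.emplace_back(std::move(v));
  }
  return offsets;
}

int main() {
  LoD lengths = {{2, 3, 1}};  // hypothetical sub-LoD for one split
  LoD offsets = LengthsToOffsets(lengths);
  assert((offsets[0] == std::vector<size_t>{0, 2, 5, 6}));
  return 0;
}
```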
 
-// TODO(tonyyang-svail): make this function support LoD
 void LoDTensor::MergeLoDTensor(
     const std::vector<const LoDTensor *> &lod_tensors,
     platform::Place dst_place) {
   PADDLE_ENFORCE(!lod_tensors.empty());
+
   framework::DDim new_dim = lod_tensors[0]->dims();
   std::type_index new_type = lod_tensors[0]->type();
-  auto new_layout = lod_tensors[0]->layout();
-  int64_t new_height = 0;
-  for (auto *lod : lod_tensors) {
-    new_height += lod->dims()[0];
-    for (int i = 1; i < new_dim.size(); ++i) {
-      PADDLE_ENFORCE_EQ(new_dim[i], lod->dims()[i]);
+  framework::DataLayout new_layout = lod_tensors[0]->layout();
+  LoD new_lod = lod_tensors[0]->lod();
+  for (size_t i = 1; i < lod_tensors.size(); ++i) {
+    auto *t = lod_tensors[i];
+    PADDLE_ENFORCE_EQ(new_type.hash_code(), t->type().hash_code());
+    PADDLE_ENFORCE_EQ(new_layout, t->layout());
+
+    PADDLE_ENFORCE_EQ(framework::product(new_dim) / new_dim[0],
+                      framework::product(t->dims()) / t->dims()[0]);
+    new_dim[0] += t->dims()[0];
+
+    auto &lod = t->lod();
+    for (size_t j = 0; j < lod.size(); ++j) {
+      auto &sub_lod = new_lod[j];
+      size_t offset = sub_lod.back();  // take a copy: push_back below may reallocate
+      for (size_t k = 1; k < lod[j].size(); ++k) {
+        sub_lod.push_back(lod[j][k] + offset);
+      }
     }
-
-    PADDLE_ENFORCE_EQ(new_type, lod->type());
-    PADDLE_ENFORCE_EQ(new_layout, lod->layout());
   }
-  new_dim[0] = new_height;
   Resize(new_dim);
   set_layout(new_layout);
-
+  set_lod(new_lod);
   mutable_data(dst_place, new_type);
+
   int begin = 0;
   for (auto *src : lod_tensors) {
     int end = begin + src->dims()[0];
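`MergeLoDTensor` is the mirror image of the split: each incoming tensor's offset-form LoD is appended level by level after shifting it by the accumulated end offset (`sub_lod.back()`), skipping the leading `0` of every appended level. A self-contained sketch of that rebasing under the same stand-in types; `MergeOffsetLoD` is a hypothetical helper name for illustration, not an existing Paddle function:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in for framework::LoD: one vector<size_t> per level.
using LoD = std::vector<std::vector<size_t>>;

// Append src's offset-form LoD onto dst's, level by level, shifting each
// src offset by dst's current end so the result stays cumulative.
void MergeOffsetLoD(LoD *dst, const LoD &src) {
  for (size_t j = 0; j < src.size(); ++j) {
    auto &sub_lod = (*dst)[j];
    size_t offset = sub_lod.back();               // copy: push_back may reallocate
    for (size_t k = 1; k < src[j].size(); ++k) {  // k = 1 skips src's leading 0
      sub_lod.push_back(src[j][k] + offset);
    }
  }
}

int main() {
  LoD merged = {{0, 2, 5}};    // two sequences covering rows [0,2) and [2,5)
  LoD incoming = {{0, 1, 4}};  // two more sequences, locally [0,1) and [1,4)
  MergeOffsetLoD(&merged, incoming);
  // Incoming offsets {1, 4} are shifted by 5 -> {6, 9}.
  assert((merged[0] == std::vector<size_t>{0, 2, 5, 6, 9}));
  return 0;
}
```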