Skip to content

Commit 1444090

Browse files
authored
[Cherry-pick] Support diff dataset tensor place in single process dataloader (#33470) (#33487)
Support different dataset tensor places in the single-process dataloader. Cherry-pick of #33470.
1 parent f57ae4d commit 1444090

File tree

3 files changed

+56
-9
lines changed

3 files changed

+56
-9
lines changed

paddle/fluid/operators/reader/buffered_reader.cc

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ BufferedReader::BufferedReader(
6868
stream_ = platform::NpuStreamResourcePool::Instance().New(dev_idx);
6969
}
7070
#endif
71-
is_same_place_ = false;
7271
cpu_buffer_.resize(buffer_size);
7372
cuda_buffer_.resize(buffer_size);
7473
npu_buffer_.resize(buffer_size);
@@ -116,7 +115,7 @@ void BufferedReader::ReadAsync(size_t i) {
116115
std::vector<void *> cuda_pinned_ptrs;
117116
cuda_pinned_ptrs.reserve(cpu.size());
118117
platform::RecordEvent record_event("BufferedReader:MemoryCopy");
119-
// NODE(chenwehiang): When we use CUDAPinned Memory, we need call
118+
// NODE(chenweihang): When we use CUDAPinned Memory, we need call
120119
// cudaHostAlloc, that is a CUDA API, calling CUDA API need load
121120
// cuda lib into device, it will cost hundreds of MB of GPU memory.
122121
// If we don't set Device here, which will use CUDAPlace(0) default.
@@ -126,18 +125,21 @@ void BufferedReader::ReadAsync(size_t i) {
126125
if (platform::is_cpu_place(cpu[i].place())) {
127126
cuda[i].Resize(cpu[i].dims());
128127
cuda[i].set_layout(cpu[i].layout());
129-
cuda_pinned_ptrs.emplace_back(
130-
cuda[i].mutable_data(cuda_pinned_place, cpu[i].type()));
128+
cuda_pinned_ptrs[i] =
129+
cuda[i].mutable_data(cuda_pinned_place, cpu[i].type());
131130
auto size =
132131
cpu[i].numel() * paddle::framework::SizeOfType(cpu[i].type());
133132

134133
memory::Copy(cuda_pinned_place, cuda_pinned_ptrs[i],
135134
BOOST_GET_CONST(platform::CPUPlace, cpu[i].place()),
136135
cpu[i].data<void>(), size);
136+
137137
cuda[i].set_lod(cpu[i].lod());
138138
} else {
139-
// we set same place flag & use cpu[i] directly
140-
is_same_place_ = true;
139+
// Here the cpu[i]'s place may be CUDAPlace, CUDAPinnedPlace, or
140+
// others, we don't copy the memory of it to CUDAPinnedPlace, but
141+
// we should share tensor data to cuda[i]
142+
cuda[i].ShareDataWith(cpu[i]);
141143
}
142144
}
143145
} else {
@@ -296,9 +298,9 @@ void BufferedReader::ReadNextImpl(std::vector<framework::LoDTensor> *out) {
296298
return;
297299
}
298300

299-
if (platform::is_gpu_place(place_) && !is_same_place_) {
301+
if (platform::is_gpu_place(place_)) {
300302
*out = std::move(cuda_buffer_[i]);
301-
} else if (platform::is_npu_place(place_) && !is_same_place_) {
303+
} else if (platform::is_npu_place(place_)) {
302304
*out = std::move(npu_buffer_[i]);
303305
} else {
304306
*out = std::move(cpu_buffer_[i]);

paddle/fluid/operators/reader/buffered_reader.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ class BufferedReader : public framework::DecoratedReader {
6767
// buffer, just read async and create futures as buffer size. However, to
6868
// malloc tensors every time is extremely slow. Here we store all data in
6969
// buffers and prevent alloc every time.
70-
bool is_same_place_;
7170
std::vector<TensorVec> cpu_buffer_;
7271
std::vector<TensorVec> cuda_buffer_;
7372
std::vector<TensorVec> npu_buffer_;

python/paddle/fluid/tests/unittests/test_dataloader_dataset.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@
1414

1515
from __future__ import division
1616

17+
import sys
1718
import unittest
1819
import numpy as np
1920

21+
import paddle
22+
import paddle.vision.transforms as transforms
2023
import paddle.fluid as fluid
2124
from paddle.io import *
2225

@@ -37,5 +40,48 @@ def test_main(self):
3740
pass
3841

3942

43+
class TestDatasetWithDiffOutputPlace(unittest.TestCase):
44+
def get_dataloader(self, num_workers):
45+
dataset = paddle.vision.datasets.MNIST(
46+
mode='test', transform=transforms.ToTensor())
47+
loader = paddle.io.DataLoader(
48+
dataset, batch_size=32, num_workers=num_workers, shuffle=True)
49+
return loader
50+
51+
def run_check_on_cpu(self):
52+
paddle.set_device('cpu')
53+
loader = self.get_dataloader(0)
54+
for image, label in loader:
55+
self.assertTrue(image.place.is_cpu_place())
56+
self.assertTrue(label.place.is_cpu_place())
57+
break
58+
59+
def test_single_process(self):
60+
self.run_check_on_cpu()
61+
if paddle.is_compiled_with_cuda():
62+
# Get (image, label) tuple from MNIST dataset
63+
# - the image is on CUDAPlace, label is on CPUPlace
64+
paddle.set_device('gpu')
65+
loader = self.get_dataloader(0)
66+
for image, label in loader:
67+
self.assertTrue(image.place.is_gpu_place())
68+
self.assertTrue(label.place.is_cuda_pinned_place())
69+
break
70+
71+
def test_multi_process(self):
72+
# DataLoader with multi-process mode is not supported on MacOs and Windows currently
73+
if sys.platform != 'darwin' and sys.platform != 'win32':
74+
self.run_check_on_cpu()
75+
if paddle.is_compiled_with_cuda():
76+
# Get (image, label) tuple from MNIST dataset
77+
# - the image and label are on CPUPlace
78+
paddle.set_device('gpu')
79+
loader = self.get_dataloader(1)
80+
for image, label in loader:
81+
self.assertTrue(image.place.is_cuda_pinned_place())
82+
self.assertTrue(label.place.is_cuda_pinned_place())
83+
break
84+
85+
4086
if __name__ == '__main__':
4187
unittest.main()

0 commit comments

Comments (0)