Skip to content

Commit 6005818

Browse files
authored
Merge pull request #13945 from sneaxiy/unify_mixed_vector_api
Unify API of mixed_vector in GPU and CPU
2 parents bcc9126 + b1fd62f commit 6005818

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

paddle/fluid/framework/mixed_vector.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,33 @@ class CPUVector : public std::vector<T, std::allocator<T>> {
542542
this->reserve(this->size() + size_t(end - begin));
543543
this->insert(this->end(), begin, end);
544544
}
545+
546+
const T *CUDAData(platform::Place place) const {
547+
PADDLE_THROW(
548+
"Vector::CUDAData() method is not supported in CPU-only version");
549+
}
550+
551+
T *CUDAMutableData(platform::Place place) {
552+
PADDLE_THROW(
553+
"Vector::CUDAMutableData() method is not supported in CPU-only "
554+
"version");
555+
}
556+
557+
const T *Data(platform::Place place) const {
558+
PADDLE_ENFORCE(
559+
platform::is_cpu_place(place),
560+
"Vector::Data() method is not supported when not in CPUPlace");
561+
return this->data();
562+
}
563+
564+
T *MutableData(platform::Place place) {
565+
PADDLE_ENFORCE(
566+
platform::is_cpu_place(place),
567+
"Vector::MutableData() method is not supported when not in CPUPlace");
568+
return this->data();
569+
}
570+
571+
const void *Handle() const { return static_cast<const void *>(this); }
545572
};
546573

547574
template <typename T>

0 commit comments

Comments
 (0)