Skip to content

Commit 037ce12

Browse files
authored
Merge pull request #11907 from reyoung/feature/use_dev_ctx_for_op
Use std::map for Place <--> DeviceContext
2 parents 71b1c39 + 2d0e559 commit 037ce12

File tree

6 files changed

+13
-32
lines changed

6 files changed

+13
-32
lines changed

paddle/fluid/framework/details/op_handle_base.cc

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -124,16 +124,10 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
124124
#ifdef PADDLE_WITH_CUDA
125125
if (!events_.empty()) { // Use event
126126
std::function<void()> method = callback;
127-
// NOTE(zcd): device context must be ordered here because RecordEvent
128-
// will use a mutex to ensure the safe of multi-threads.
129-
std::map<platform::DeviceContext *, platform::Place> ordered_ctxes;
130127
for (auto &p : dev_ctxes_) {
131-
ordered_ctxes.emplace(p.second, p.first);
132-
}
133-
for (auto &p : ordered_ctxes) {
134128
method = [method, p, this]() {
135-
static_cast<platform::CUDADeviceContext *>(p.first)->RecordEvent(
136-
events_.at(boost::get<platform::CUDAPlace>(p.second).device),
129+
static_cast<platform::CUDADeviceContext *>(p.second)->RecordEvent(
130+
events_.at(boost::get<platform::CUDAPlace>(p.first).device),
137131
method);
138132
};
139133
}

paddle/fluid/framework/details/op_handle_base.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
// limitations under the License.
1414

1515
#pragma once
16+
#include <map>
1617
#include <string>
1718
#include <vector>
18-
1919
#include "paddle/fluid/framework/details/var_handle.h"
2020
#include "paddle/fluid/platform/device_context.h"
2121
#include "paddle/fluid/platform/macros.h"
@@ -92,9 +92,7 @@ class OpHandleBase {
9292

9393
std::vector<VarHandleBase *> inputs_;
9494
std::vector<VarHandleBase *> outputs_;
95-
std::unordered_map<platform::Place, platform::DeviceContext *,
96-
platform::PlaceHash>
97-
dev_ctxes_;
95+
std::map<platform::Place, platform::DeviceContext *> dev_ctxes_;
9896

9997
#ifdef PADDLE_WITH_CUDA
10098
std::unordered_map<int, cudaEvent_t> events_;

paddle/fluid/framework/details/reduce_and_gather.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ struct ReduceLoDTensor {
5454
inline void GatherSelectedRows(
5555
const std::vector<const SelectedRows *> &src_selecte_rows_,
5656
const std::vector<platform::Place> &in_places,
57-
const std::unordered_map<platform::Place, platform::DeviceContext *,
58-
platform::PlaceHash> &dev_ctxes,
57+
const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes,
5958
const platform::Place &out_place, SelectedRows *dst_selecte_rows) {
6059
PADDLE_ENFORCE(!src_selecte_rows_.empty());
6160

paddle/fluid/platform/device_context.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
1010
limitations under the License. */
1111
#include "paddle/fluid/platform/device_context.h"
1212

13+
#include <set>
1314
#include <string>
1415
#include <unordered_set>
1516
#include <vector>
@@ -35,7 +36,7 @@ DeviceContextPool::DeviceContextPool(
3536
const std::vector<platform::Place>& places) {
3637
PADDLE_ENFORCE_GT(places.size(), 0);
3738
using PtrType = std::unique_ptr<DeviceContext>;
38-
std::unordered_set<Place, PlaceHash> set;
39+
std::set<Place> set;
3940
for (auto& p : places) {
4041
set.insert(p);
4142
}

paddle/fluid/platform/device_context.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ limitations under the License. */
2727
#include <mkldnn.hpp>
2828
#endif
2929

30+
#include <map>
31+
#include "glog/logging.h"
3032
#include "paddle/fluid/platform/enforce.h"
3133
#include "paddle/fluid/platform/place.h"
3234
#include "unsupported/Eigen/CXX11/Tensor"
3335

34-
#include "glog/logging.h"
35-
3636
namespace paddle {
3737
namespace platform {
3838

@@ -201,9 +201,7 @@ class DeviceContextPool {
201201

202202
private:
203203
static DeviceContextPool* pool;
204-
std::unordered_map<const platform::Place,
205-
std::unique_ptr<platform::DeviceContext>, PlaceHash>
206-
device_contexts_;
204+
std::map<Place, std::unique_ptr<DeviceContext>> device_contexts_;
207205
DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
208206
};
209207

paddle/fluid/platform/place.h

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ struct CPUPlace {
3030
// needed for variant equality comparison
3131
inline bool operator==(const CPUPlace &) const { return true; }
3232
inline bool operator!=(const CPUPlace &) const { return false; }
33+
inline bool operator<(const CPUPlace &) const { return false; }
3334
};
3435

3536
struct CUDAPlace {
@@ -42,6 +43,7 @@ struct CUDAPlace {
4243
return device == o.device;
4344
}
4445
inline bool operator!=(const CUDAPlace &o) const { return !(*this == o); }
46+
inline bool operator<(const CUDAPlace &o) const { return device < o.device; }
4547

4648
int device;
4749
};
@@ -52,6 +54,7 @@ struct CUDAPinnedPlace {
5254
// needed for variant equality comparison
5355
inline bool operator==(const CUDAPinnedPlace &) const { return true; }
5456
inline bool operator!=(const CUDAPinnedPlace &) const { return false; }
57+
inline bool operator<(const CUDAPinnedPlace &) const { return false; }
5558
};
5659

5760
struct IsCUDAPlace : public boost::static_visitor<bool> {
@@ -89,18 +92,6 @@ bool is_cuda_pinned_place(const Place &);
8992
bool places_are_same_class(const Place &, const Place &);
9093
bool is_same_place(const Place &, const Place &);
9194

92-
struct PlaceHash {
93-
std::size_t operator()(const Place &p) const {
94-
constexpr size_t num_dev_bits = 4;
95-
std::hash<int> ihash;
96-
size_t dev_id = 0;
97-
if (is_gpu_place(p)) {
98-
dev_id = boost::get<CUDAPlace>(p).device;
99-
}
100-
return ihash(dev_id << num_dev_bits | p.which());
101-
}
102-
};
103-
10495
std::ostream &operator<<(std::ostream &, const Place &);
10596

10697
template <typename Visitor>

0 commit comments

Comments
 (0)