Skip to content

Commit 5046869

Browse files
authored
Merge pull request #8287 from tonyyang-svail/operator_set_device
Correctly handle cuda place for operators
2 parents 7757a8a + 40c7972 commit 5046869

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+214
-114
lines changed

paddle/fluid/framework/op_registry_test.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ namespace framework {
2525
class CosineOp : public OperatorBase {
2626
public:
2727
using OperatorBase::OperatorBase;
28-
void Run(const Scope& scope, const platform::Place& place) const override {}
28+
29+
private:
30+
void RunImpl(const Scope& scope,
31+
const platform::Place& place) const override {}
2932
};
3033

3134
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
@@ -44,7 +47,10 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
4447
class MyTestOp : public OperatorBase {
4548
public:
4649
using OperatorBase::OperatorBase;
47-
void Run(const Scope& scope, const platform::Place& place) const override {}
50+
51+
private:
52+
void RunImpl(const Scope& scope,
53+
const platform::Place& place) const override {}
4854
};
4955

5056
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {

paddle/fluid/framework/operator.cc

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,18 @@ static LoD GetLoD(const Scope& scope, const std::string& name) {
6464
}
6565
}
6666

67+
void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
68+
if (platform::is_gpu_place(place)) {
69+
#ifndef PADDLE_WITH_CUDA
70+
PADDLE_THROW("Cannot run operator on place %s", place);
71+
#else
72+
auto dev_id = boost::get<platform::CUDAPlace>(place).device;
73+
platform::SetDeviceId(dev_id);
74+
#endif
75+
}
76+
RunImpl(scope, place);
77+
}
78+
6779
std::string OperatorBase::Input(const std::string& name) const {
6880
auto& ins = Inputs(name);
6981
PADDLE_ENFORCE_LE(ins.size(), 1UL,
@@ -479,8 +491,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
479491
const Scope& scope_;
480492
};
481493

482-
void OperatorWithKernel::Run(const Scope& scope,
483-
const platform::Place& place) const {
494+
void OperatorWithKernel::RunImpl(const Scope& scope,
495+
const platform::Place& place) const {
484496
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
485497
this->InferShape(&infer_shape_ctx);
486498
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();

paddle/fluid/framework/operator.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,9 @@ class OperatorBase {
8989

9090
std::string DebugString() const { return DebugStringEx(nullptr); }
9191

92-
/// Net will call this function to Run an op.
93-
virtual void Run(const Scope& scope, const platform::Place& place) const = 0;
92+
/// Net will call this interface function to Run an op.
93+
// The actual implementation should be written in RunImpl.
94+
void Run(const Scope& scope, const platform::Place& place);
9495

9596
// FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
9697
virtual void Stop() {}
@@ -144,6 +145,8 @@ class OperatorBase {
144145
private:
145146
void GenerateTemporaryNames();
146147
void CheckAllInputOutputSet() const;
148+
virtual void RunImpl(const Scope& scope,
149+
const platform::Place& place) const = 0;
147150
};
148151

149152
// Macro for define a clone method.
@@ -168,10 +171,13 @@ class OperatorBase {
168171
class NOP : public OperatorBase {
169172
public:
170173
using OperatorBase::OperatorBase;
171-
void Run(const Scope& scope, const platform::Place& place) const override {}
172174
std::unique_ptr<OperatorBase> Clone() const override {
173175
return std::unique_ptr<OperatorBase>(new NOP(*this));
174176
}
177+
178+
private:
179+
void RunImpl(const Scope& scope,
180+
const platform::Place& place) const override {}
175181
};
176182

177183
class ExecutionContext {
@@ -363,8 +369,6 @@ class OperatorWithKernel : public OperatorBase {
363369
const VariableNameMap& outputs, const AttributeMap& attrs)
364370
: OperatorBase(type, inputs, outputs, attrs) {}
365371

366-
void Run(const Scope& scope, const platform::Place& place) const final;
367-
368372
static std::unordered_map<std::string /* op_type */, OpKernelMap>&
369373
AllOpKernels() {
370374
static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
@@ -393,6 +397,7 @@ class OperatorWithKernel : public OperatorBase {
393397
// indicate kernel DataType by input data. By default, all input data must be
394398
// the same.
395399
proto::DataType IndicateDataType(const ExecutionContext& ctx) const;
400+
void RunImpl(const Scope& scope, const platform::Place& place) const final;
396401
};
397402

398403
extern bool OpSupportGPU(const std::string& op_type);

paddle/fluid/framework/operator_test.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ class OpWithoutKernelTest : public OperatorBase {
2828
OpWithoutKernelTest(const std::string& type, const VariableNameMap& inputs,
2929
const VariableNameMap& outputs, const AttributeMap& attrs)
3030
: OperatorBase(type, inputs, outputs, attrs), x(1) {}
31-
void Run(const Scope& scope, const platform::Place& place) const override {
31+
32+
private:
33+
void RunImpl(const Scope& scope,
34+
const platform::Place& place) const override {
3235
++op_run_num;
3336
ASSERT_EQ(static_cast<int>(inputs_.size()), 1);
3437
ASSERT_EQ(static_cast<int>(outputs_.size()), 1);
@@ -259,8 +262,10 @@ class OperatorClone : public paddle::framework::OperatorBase {
259262
const paddle::framework::VariableNameMap& outputs,
260263
const paddle::framework::AttributeMap& attrs)
261264
: OperatorBase(type, inputs, outputs, attrs) {}
262-
void Run(const paddle::framework::Scope& scope,
263-
const paddle::platform::Place& place) const override {}
265+
266+
private:
267+
void RunImpl(const paddle::framework::Scope& scope,
268+
const paddle::platform::Place& place) const override {}
264269
};
265270

266271
TEST(Operator, Clone) {

paddle/fluid/operators/array_to_lod_tensor_op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
3131
const framework::VariableNameMap &outputs,
3232
const framework::AttributeMap &attrs)
3333
: OperatorBase(type, inputs, outputs, attrs) {}
34-
void Run(const framework::Scope &scope,
35-
const platform::Place &dev_place) const override {
34+
35+
private:
36+
void RunImpl(const framework::Scope &scope,
37+
const platform::Place &dev_place) const override {
3638
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
3739
auto &rank_table =
3840
scope.FindVar(Input("RankTable"))->Get<framework::LoDRankTable>();

paddle/fluid/operators/assign_op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,10 @@ class AssignOp : public framework::OperatorBase {
7171
const framework::VariableNameMap &outputs,
7272
const framework::AttributeMap &attrs)
7373
: OperatorBase(type, inputs, outputs, attrs) {}
74-
void Run(const framework::Scope &scope,
75-
const platform::Place &place) const override {
74+
75+
private:
76+
void RunImpl(const framework::Scope &scope,
77+
const platform::Place &place) const override {
7678
auto *x = scope.FindVar(Input("X"));
7779
if (x == nullptr) {
7880
return;

paddle/fluid/operators/beam_search_decode_op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,10 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
5555
const framework::VariableNameMap& outputs,
5656
const framework::AttributeMap& attrs)
5757
: OperatorBase(type, inputs, outputs, attrs) {}
58-
void Run(const framework::Scope& scope,
59-
const platform::Place& dev_place) const override {
58+
59+
private:
60+
void RunImpl(const framework::Scope& scope,
61+
const platform::Place& dev_place) const override {
6062
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
6163
auto& dev_ctx = *pool.Get(dev_place);
6264

paddle/fluid/operators/beam_search_op.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,9 @@ class BeamSearchOp : public framework::OperatorBase {
204204
PADDLE_THROW("Not Implemented");
205205
}
206206

207-
void Run(const framework::Scope& scope,
208-
const platform::Place& dev_place) const override {
207+
private:
208+
void RunImpl(const framework::Scope& scope,
209+
const platform::Place& dev_place) const override {
209210
auto ids_var = scope.FindVar(Input("ids"));
210211
auto scores_var = scope.FindVar(Input("scores"));
211212
auto pre_ids_var = scope.FindVar(Input("pre_ids"));

paddle/fluid/operators/cond_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
193193
}
194194
}
195195

196-
void CondOp::Run(const Scope& scope, const platform::Place& place) const {
196+
void CondOp::RunImpl(const Scope& scope, const platform::Place& place) const {
197197
// get device context from pool
198198
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
199199
auto& dev_ctx = *pool.Get(place);

paddle/fluid/operators/cond_op.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,9 @@ class CondOp : public framework::OperatorBase {
7777
sub_net_op_[FALSE_BRANCH] = std::move(net);
7878
}
7979

80-
void Run(const framework::Scope& scope,
81-
const platform::Place& place) const override;
80+
private:
81+
void RunImpl(const framework::Scope& scope,
82+
const platform::Place& place) const override;
8283

8384
private:
8485
const int TRUE_BRANCH = 0;

0 commit comments

Comments
 (0)