Skip to content

Commit 91b7248

Browse files
authored
Merge pull request #5443 from reyoung/feature/InferKernelKey
Polish OpWithKernel
2 parents fcc50cb + 46610e9 commit 91b7248

25 files changed

+185
-126
lines changed

doc/design/float16.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,6 @@ After float16 class is available, some of the future items are below:
5555

5656
- Update pybind/tensor_py.h to bind c++ float16 with numpy float16.
5757

58-
- Modify `IndicateDataType()` method in `framework/operator.h` to make it compatible with float16.
58+
- Modify `GetKernelType()` method in `framework/operator.h` to make it compatible with float16.
5959

6060
- Create a type-casting operator that can convert the data type in tensor between float16 and other types.

paddle/framework/op_registry.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,7 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
9292

9393
void operator()(const char* op_type) const {
9494
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
95-
OperatorWithKernel::OpKernelKey key(ToDataType(std::type_index(typeid(T))),
96-
PlaceType());
95+
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType());
9796
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);
9897

9998
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;

paddle/framework/operator.cc

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,7 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
254254
return res;
255255
}
256256

257-
std::ostream& operator<<(std::ostream& os,
258-
const OperatorWithKernel::OpKernelKey& kernel_key) {
257+
std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key) {
259258
os << "place[" << kernel_key.place_ << "]:data_type[" << kernel_key.data_type_
260259
<< "]";
261260
return os;
@@ -432,7 +431,7 @@ void OperatorWithKernel::Run(const Scope& scope,
432431

433432
// check if op[type] have kernel for kernel_key
434433
OpKernelMap& kernels = kernels_iter->second;
435-
auto kernel_key = OpKernelKey(IndicateDataType(ctx), dev_ctx);
434+
auto kernel_key = GetKernelType(ctx);
436435
auto kernel_iter = kernels.find(kernel_key);
437436

438437
if (kernel_iter == kernels.end()) {
@@ -444,6 +443,38 @@ void OperatorWithKernel::Run(const Scope& scope,
444443
// throws errors if have.
445444
dev_ctx.Finish();
446445
}
446+
OpKernelType OperatorWithKernel::GetKernelType(
447+
const ExecutionContext& ctx) const {
448+
return OpKernelType(IndicateDataType(ctx), ctx.device_context());
449+
}
450+
DataType OperatorWithKernel::IndicateDataType(
451+
const ExecutionContext& ctx) const {
452+
auto& scope = ctx.scope();
453+
int data_type = -1;
454+
for (auto& input : this->inputs_) {
455+
for (auto& ipt_name : input.second) {
456+
auto* var = scope.FindVar(ipt_name);
457+
if (var != nullptr) {
458+
const Tensor* t = nullptr;
459+
if (var->IsType<Tensor>()) {
460+
t = &var->Get<Tensor>();
461+
} else if (var->IsType<LoDTensor>()) {
462+
t = &var->Get<LoDTensor>();
463+
} else if (var->IsType<SelectedRows>()) {
464+
t = &(var->Get<SelectedRows>().value());
465+
}
466+
if (t != nullptr) {
467+
int tmp = static_cast<int>(ToDataType(t->type()));
468+
PADDLE_ENFORCE(tmp == data_type || data_type == -1,
469+
"DataType of Paddle Op %s must be the same.", Type());
470+
data_type = tmp;
471+
}
472+
}
473+
}
474+
}
475+
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
476+
return static_cast<DataType>(data_type);
477+
}
447478

448479
} // namespace framework
449480
} // namespace paddle

paddle/framework/operator.h

Lines changed: 27 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -345,27 +345,10 @@ class OpKernel : public OpKernelBase {
345345
using ELEMENT_TYPE = T;
346346
};
347347

348-
class OperatorWithKernel : public OperatorBase {
349-
public:
350-
struct OpKernelKey {
351-
platform::Place place_;
352-
DataType data_type_;
353-
354-
OpKernelKey(DataType data_type, platform::Place place)
355-
: place_(place), data_type_(data_type) {}
356-
357-
OpKernelKey(DataType data_type, const platform::DeviceContext& dev_ctx)
358-
: place_(dev_ctx.GetPlace()), data_type_(data_type) {}
359-
360-
bool operator==(const OpKernelKey& o) const {
361-
return platform::places_are_same_class(place_, o.place_) &&
362-
data_type_ == o.data_type_;
363-
}
364-
};
365-
366-
struct OpKernelHash {
348+
struct OpKernelType {
349+
struct Hash {
367350
std::hash<int> hash_;
368-
size_t operator()(const OpKernelKey& key) const {
351+
size_t operator()(const OpKernelType& key) const {
369352
int place = key.place_.which();
370353
int data_type = static_cast<int>(key.data_type_);
371354
int pre_hash = data_type << NUM_PLACE_TYPE_LIMIT_IN_BIT |
@@ -374,9 +357,26 @@ class OperatorWithKernel : public OperatorBase {
374357
}
375358
};
376359

360+
platform::Place place_;
361+
DataType data_type_;
362+
363+
OpKernelType(DataType data_type, platform::Place place)
364+
: place_(place), data_type_(data_type) {}
365+
366+
OpKernelType(DataType data_type, const platform::DeviceContext& dev_ctx)
367+
: place_(dev_ctx.GetPlace()), data_type_(data_type) {}
368+
369+
bool operator==(const OpKernelType& o) const {
370+
return platform::places_are_same_class(place_, o.place_) &&
371+
data_type_ == o.data_type_;
372+
}
373+
};
374+
375+
class OperatorWithKernel : public OperatorBase {
376+
public:
377377
using OpKernelMap =
378-
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernelBase>,
379-
OpKernelHash>;
378+
std::unordered_map<OpKernelType, std::unique_ptr<OpKernelBase>,
379+
OpKernelType::Hash>;
380380

381381
OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
382382
const VariableNameMap& outputs, const AttributeMap& attrs)
@@ -404,40 +404,15 @@ class OperatorWithKernel : public OperatorBase {
404404
}
405405

406406
protected:
407+
virtual OpKernelType GetKernelType(const ExecutionContext& ctx) const;
408+
409+
private:
407410
// indicate kernel DataType by input data. Defaultly all input data must be
408411
// same.
409-
virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
410-
auto& scope = ctx.scope();
411-
int data_type = -1;
412-
for (auto& input : this->inputs_) {
413-
for (auto& ipt_name : input.second) {
414-
auto* var = scope.FindVar(ipt_name);
415-
if (var != nullptr) {
416-
const Tensor* t = nullptr;
417-
if (var->IsType<Tensor>()) {
418-
t = &var->Get<Tensor>();
419-
} else if (var->IsType<LoDTensor>()) {
420-
t = &var->Get<LoDTensor>();
421-
} else if (var->IsType<SelectedRows>()) {
422-
t = &(var->Get<SelectedRows>().value());
423-
}
424-
if (t != nullptr) {
425-
int tmp = static_cast<int>(ToDataType(t->type()));
426-
PADDLE_ENFORCE(tmp == data_type || data_type == -1,
427-
"DataType of Paddle Op %s must be the same.",
428-
Type());
429-
data_type = tmp;
430-
}
431-
}
432-
}
433-
}
434-
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
435-
return static_cast<DataType>(data_type);
436-
}
412+
DataType IndicateDataType(const ExecutionContext& ctx) const;
437413
};
438414

439-
std::ostream& operator<<(std::ostream& os,
440-
const OperatorWithKernel::OpKernelKey& kernel_key);
415+
std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key);
441416

442417
extern bool OpSupportGPU(const std::string& op_type);
443418

paddle/framework/operator_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ class OpWithKernelTest : public OperatorWithKernel {
114114

115115
protected:
116116
void InferShape(framework::InferShapeContext* ctx) const override {}
117-
DataType IndicateDataType(const ExecutionContext& ctx) const override {
118-
return DataType::FP32;
117+
OpKernelType GetKernelType(const ExecutionContext& ctx) const override {
118+
return OpKernelType(DataType::FP32, ctx.device_context());
119119
}
120120
};
121121

paddle/operators/accuracy_op.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,11 @@ class AccuracyOp : public framework::OperatorWithKernel {
4747
}
4848

4949
protected:
50-
// IndicateDataType
51-
framework::DataType IndicateDataType(
50+
framework::OpKernelType GetKernelType(
5251
const framework::ExecutionContext &ctx) const override {
53-
return framework::ToDataType(ctx.Input<Tensor>("Out")->type());
52+
return framework::OpKernelType(
53+
framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
54+
ctx.device_context());
5455
}
5556
};
5657

paddle/operators/auc_op.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,11 @@ class AucOp : public framework::OperatorWithKernel {
3939
}
4040

4141
protected:
42-
// IndicateDataType
43-
framework::DataType IndicateDataType(
42+
framework::OpKernelType GetKernelType(
4443
const framework::ExecutionContext &ctx) const override {
45-
return framework::ToDataType(ctx.Input<Tensor>("Out")->type());
44+
return framework::OpKernelType(
45+
framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
46+
ctx.device_context());
4647
}
4748
};
4849

paddle/operators/batch_norm_op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
303303
ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
304304
}
305305

306-
framework::DataType IndicateDataType(
306+
protected:
307+
framework::OpKernelType GetKernelType(
307308
const framework::ExecutionContext &ctx) const override {
308309
const auto *var = ctx.InputVar(framework::GradVarName("Y"));
309310
if (var == nullptr) {
@@ -318,7 +319,8 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
318319
if (t == nullptr) {
319320
PADDLE_THROW("can't find Y@GRAD");
320321
}
321-
return framework::ToDataType(t->type());
322+
return framework::OpKernelType(framework::ToDataType(t->type()),
323+
ctx.device_context());
322324
}
323325
};
324326

paddle/operators/crf_decoding_op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,11 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
120120
}
121121

122122
protected:
123-
framework::DataType IndicateDataType(
123+
framework::OpKernelType GetKernelType(
124124
const framework::ExecutionContext& ctx) const override {
125-
return framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type());
125+
return framework::OpKernelType(
126+
framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
127+
ctx.device_context());
126128
}
127129
};
128130
} // namespace operators

paddle/operators/cross_entropy_op.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,11 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
5151
protected:
5252
// Explicitly set that the data type of computation kernel of cross_entropy
5353
// is determined by its input "X".
54-
framework::DataType IndicateDataType(
54+
framework::OpKernelType GetKernelType(
5555
const framework::ExecutionContext& ctx) const override {
56-
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
56+
return framework::OpKernelType(
57+
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
58+
ctx.device_context());
5759
}
5860
};
5961

@@ -98,9 +100,11 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
98100
protected:
99101
// Explicitly set that the data type of computation kernel of cross_entropy
100102
// is determined by its input "X".
101-
framework::DataType IndicateDataType(
103+
framework::OpKernelType GetKernelType(
102104
const framework::ExecutionContext& ctx) const override {
103-
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
105+
return framework::OpKernelType(
106+
framework::ToDataType(ctx.Input<Tensor>("X")->type()),
107+
ctx.device_context());
104108
}
105109
};
106110

0 commit comments

Comments
 (0)