Skip to content

Commit e66bd4c

Browse files
committed
add GetDataTypeOfVar
1 parent cbfec1f commit e66bd4c

File tree

4 files changed

+19
-31
lines changed

4 files changed

+19
-31
lines changed

paddle/fluid/framework/operator.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,17 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
3535
std::make_tuple(platform::CPUPlace(), LibraryType::kPlain),
3636
};
3737

38+
proto::VarType::Type GetDataTypeOfVar(const Variable* var) {
39+
if (var->IsType<framework::LoDTensor>()) {
40+
return framework::ToDataType(var->Get<framework::LoDTensor>().type());
41+
} else if (var->IsType<framework::SelectedRows>()) {
42+
return framework::ToDataType(
43+
var->Get<framework::SelectedRows>().value().type());
44+
} else {
45+
PADDLE_THROW("Var should be LoDTensor or SelectedRows");
46+
}
47+
}
48+
3849
static DDim GetDims(const Scope& scope, const std::string& name) {
3950
Variable* var = scope.FindVar(name);
4051
if (var == nullptr) {

paddle/fluid/framework/operator.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ inline std::string GradVarName(const std::string& var_name) {
6161
return var_name + kGradVarSuffix;
6262
}
6363

64+
proto::VarType::Type GetDataTypeOfVar(const Variable* var);
65+
6466
class OperatorBase;
6567
class ExecutionContext;
6668

paddle/fluid/operators/lookup_table_op.cc

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,6 @@ limitations under the License. */
1818
namespace paddle {
1919
namespace operators {
2020

21-
static inline framework::OpKernelType ExpectedKernelType(
22-
const framework::ExecutionContext& ctx) {
23-
auto* table_var = ctx.InputVar("W");
24-
if (table_var->IsType<LoDTensor>()) {
25-
return framework::OpKernelType(
26-
framework::ToDataType(table_var->Get<LoDTensor>().type()),
27-
ctx.device_context());
28-
} else if (table_var->IsType<SelectedRows>()) {
29-
return framework::OpKernelType(
30-
framework::ToDataType(table_var->Get<SelectedRows>().value().type()),
31-
ctx.device_context());
32-
} else {
33-
PADDLE_THROW("W should be LoDTensor or SelectedRows");
34-
}
35-
}
36-
3721
class LookupTableOp : public framework::OperatorWithKernel {
3822
public:
3923
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -67,7 +51,8 @@ class LookupTableOp : public framework::OperatorWithKernel {
6751
protected:
6852
framework::OpKernelType GetExpectedKernelType(
6953
const framework::ExecutionContext& ctx) const override {
70-
return ExpectedKernelType(ctx);
54+
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W"));
55+
return framework::OpKernelType(data_type, ctx.device_context());
7156
}
7257
};
7358

@@ -138,7 +123,8 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
138123
protected:
139124
framework::OpKernelType GetExpectedKernelType(
140125
const framework::ExecutionContext& ctx) const override {
141-
return ExpectedKernelType(ctx);
126+
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W"));
127+
return framework::OpKernelType(data_type, ctx.device_context());
142128
}
143129
};
144130

paddle/fluid/operators/sgd_op.cc

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,19 +43,8 @@ class SGDOp : public framework::OperatorWithKernel {
4343
protected:
4444
framework::OpKernelType GetExpectedKernelType(
4545
const framework::ExecutionContext& ctx) const override {
46-
auto* table_var = ctx.InputVar("Param");
47-
if (table_var->IsType<framework::LoDTensor>()) {
48-
return framework::OpKernelType(
49-
framework::ToDataType(table_var->Get<framework::LoDTensor>().type()),
50-
ctx.device_context());
51-
} else if (table_var->IsType<framework::SelectedRows>()) {
52-
return framework::OpKernelType(
53-
framework::ToDataType(
54-
table_var->Get<framework::SelectedRows>().value().type()),
55-
ctx.device_context());
56-
} else {
57-
PADDLE_THROW("Param should be LoDTensor or SelectedRows");
58-
}
46+
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
47+
return framework::OpKernelType(data_type, ctx.device_context());
5948
}
6049
};
6150

0 commit comments

Comments
 (0)