Skip to content

Commit b84e822

Browse files
authored
Cast Operator (#5149)
* Cast Operator Cast input variable to other data type * Fix compile error * Add cast op * Follow comments
1 parent 46a13e3 commit b84e822

File tree

7 files changed

+222
-1
lines changed

7 files changed

+222
-1
lines changed

paddle/framework/data_type.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,5 +34,25 @@ inline DataType ToDataType(std::type_index type) {
3434
}
3535
}
3636

37+
template <typename Visitor>
38+
inline void VisitDataType(DataType type, Visitor visitor) {
39+
switch (type) {
40+
case DataType::FP32:
41+
visitor.template operator()<float>();
42+
break;
43+
case DataType::FP64:
44+
visitor.template operator()<double>();
45+
break;
46+
case DataType::INT32:
47+
visitor.template operator()<int>();
48+
break;
49+
case DataType::INT64:
50+
visitor.template operator()<int64_t>();
51+
break;
52+
default:
53+
PADDLE_THROW("Not supported");
54+
}
55+
}
56+
3757
} // namespace framework
3858
} // namespace paddle

paddle/framework/op_registry.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,10 @@ class OpKernelRegistrar : public Registrar {
162162
REGISTER_OPERATOR(op_type, op_class, _GradOpDescMaker_##grad_op_type##_, \
163163
op_maker_class);
164164

165+
#define REGISTER_OP_WITH_KERNEL(op_type, ...) \
166+
REGISTER_OPERATOR(op_type, ::paddle::framework::OperatorWithKernel, \
167+
##__VA_ARGS__)
168+
165169
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
166170
REGISTER_OPERATOR(op_type, op_class, op_maker_class)
167171

paddle/operators/cast_op.cc

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/cast_op.h"
16+
#include "paddle/framework/op_registry.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
class CastOpProtoMaker : public framework::OpProtoAndCheckerMaker {
22+
public:
23+
CastOpProtoMaker(framework::OpProto *proto,
24+
framework::OpAttrChecker *op_checker)
25+
: OpProtoAndCheckerMaker(proto, op_checker) {
26+
AddInput("X", "the input tensor of cast op");
27+
AddOutput("Out", "the output tensor of cast op");
28+
AddComment(R"DOC(Cast operator.
29+
cast the input tensor to other data type.
30+
)DOC");
31+
AddAttr<int>("out_data_type", "output data type");
32+
AddAttr<int>("in_data_type", "input data type");
33+
}
34+
};
35+
36+
class CastOpInferShape : public framework::InferShapeBase {
37+
public:
38+
void operator()(framework::InferShapeContext *context) const override {
39+
PADDLE_ENFORCE(context->HasInput("X"), "The input of cast op must be set");
40+
PADDLE_ENFORCE(context->HasOutput("Out"),
41+
"The output of cast op must be set");
42+
context->SetOutputDim("Out", context->GetInputDim("X"));
43+
context->ShareLoD("X", "Out");
44+
}
45+
};
46+
47+
class CastOpGradMaker : public framework::SingleGradOpDescMaker {
48+
public:
49+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
50+
51+
protected:
52+
std::unique_ptr<framework::OpDescBind> Apply() const override {
53+
auto grad = new framework::OpDescBind();
54+
grad->SetType("cast");
55+
grad->SetInput("X", OutputGrad("Out"));
56+
grad->SetOutput("Out", InputGrad("X"));
57+
grad->SetAttr("out_data_type", GetAttr("in_data_type"));
58+
grad->SetAttr("in_data_type", GetAttr("out_data_type"));
59+
return std::unique_ptr<framework::OpDescBind>(grad);
60+
}
61+
};
62+
63+
} // namespace operators
64+
} // namespace paddle
65+
66+
namespace ops = paddle::operators;
67+
using CPU = paddle::platform::CPUPlace;
68+
REGISTER_OP_WITH_KERNEL(cast, ops::CastOpGradMaker, ops::CastOpInferShape,
69+
ops::CastOpProtoMaker);
70+
REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
71+
ops::CastOpKernel<CPU, double>,
72+
ops::CastOpKernel<CPU, int>,
73+
ops::CastOpKernel<CPU, int64_t>);

paddle/operators/cast_op.cu

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/cast_op.h"
16+
17+
template <typename T>
18+
using CastOpKernel =
19+
paddle::operators::CastOpKernel<paddle::platform::GPUPlace, T>;
20+
21+
REGISTER_OP_GPU_KERNEL(cast, CastOpKernel<float>, CastOpKernel<double>,
22+
CastOpKernel<int>, CastOpKernel<int64_t>);

paddle/operators/cast_op.h

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "paddle/framework/data_type.h"
18+
#include "paddle/framework/framework.pb.h"
19+
#include "paddle/framework/op_registry.h"
20+
#include "paddle/platform/transform.h"
21+
22+
namespace paddle {
23+
namespace operators {
24+
25+
template <typename InT, typename OutT>
26+
struct CastOpTransformFunctor {
27+
HOSTDEVICE OutT operator()(InT in) const { return static_cast<OutT>(in); }
28+
};
29+
30+
template <typename Place, typename InT>
31+
struct CastOpFunctor {
32+
const framework::Tensor* in_;
33+
framework::Tensor* out_;
34+
const platform::DeviceContext& ctx_;
35+
CastOpFunctor(const framework::Tensor* in, framework::Tensor* out,
36+
const platform::DeviceContext& ctx)
37+
: in_(in), out_(out), ctx_(ctx) {}
38+
39+
template <typename OutT>
40+
void operator()() const {
41+
auto* in_begin = in_->data<InT>();
42+
auto numel = in_->numel();
43+
auto* in_end = in_begin + numel;
44+
auto* out_begin = out_->mutable_data<OutT>(ctx_.GetPlace());
45+
platform::Transform<Place> trans;
46+
trans(ctx_, in_begin, in_end, out_begin,
47+
CastOpTransformFunctor<InT, OutT>());
48+
}
49+
};
50+
51+
template <typename Place, typename InT>
52+
class CastOpKernel : public framework::OpKernel<InT> {
53+
public:
54+
void Compute(const framework::ExecutionContext& context) const override {
55+
auto* in = context.Input<framework::Tensor>("X");
56+
auto* out = context.Output<framework::Tensor>("Out");
57+
framework::VisitDataType(
58+
static_cast<framework::DataType>(context.Attr<int>("out_data_type")),
59+
CastOpFunctor<Place, InT>(in, out, context.device_context()));
60+
}
61+
};
62+
63+
} // namespace operators
64+
} // namespace paddle

python/paddle/v2/framework/layers.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
__all__ = [
77
'fc', 'data', 'cross_entropy', 'conv2d', 'pool2d', 'embedding', 'concat',
8-
'StaticRNN'
8+
'StaticRNN', 'cast'
99
]
1010

1111

@@ -163,6 +163,18 @@ def func(**kwargs):
163163
_create_op_func_('dropout')
164164

165165

166+
def cast(x, data_type, program=None):
167+
helper = LayerHelper('cast', **locals())
168+
out = helper.create_tmp_variable(dtype=data_type)
169+
helper.append_op(
170+
type='cast',
171+
inputs={'X': [x]},
172+
outputs={'Out': [out]},
173+
attrs={'in_data_type': x.data_type,
174+
'out_data_type': out.data_type})
175+
return out
176+
177+
166178
def concat(input, axis, program=None, init_program=None):
167179
helper = LayerHelper('concat', **locals())
168180
if not isinstance(input, list) and not isinstance(input, tuple):
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import op_test
2+
import unittest
3+
import numpy as np
4+
import paddle.v2.framework.core as core
5+
6+
7+
class TestCastOp(op_test.OpTest):
8+
def setUp(self):
9+
ipt = np.random.random(size=[10, 10])
10+
self.inputs = {'X': ipt.astype('float32')}
11+
self.outputs = {'Out': ipt.astype('float64')}
12+
self.attrs = {
13+
'in_data_type': int(core.DataType.FP32),
14+
'out_data_type': int(core.DataType.FP64)
15+
}
16+
self.op_type = 'cast'
17+
18+
def test_check_output(self):
19+
self.check_output()
20+
21+
def test_grad(self):
22+
self.check_grad(['X'], ['Out'])
23+
24+
25+
if __name__ == '__main__':
26+
unittest.main()

0 commit comments

Comments
 (0)