#include <cmath>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"

#define SQRT_2_PI 0.7978845608028654
6+
7+ using namespace tensorflow ;
8+ using CPUDevice = Eigen::ThreadPoolDevice;
9+ using GPUDevice = Eigen::GpuDevice;
10+
11+ REGISTER_OP (" Gelu" )
12+ .Attr(" T: {float, double}" )
13+ .Input(" x: T" )
14+ .Output(" output: T" )
15+ .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
16+ c->set_output (0 , c->input (0 ));
17+ return Status::OK ();
18+ });
19+
20+ REGISTER_OP (" GeluGrad" )
21+ .Attr(" T: {float, double}" )
22+ .Input(" dy: T" )
23+ .Input(" x: T" )
24+ .Output(" output: T" )
25+ .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
26+ c->set_output (0 , c->input (1 ));
27+ return Status::OK ();
28+ });
29+
30+ REGISTER_OP (" GeluGradGrad" )
31+ .Attr(" T: {float, double}" )
32+ .Input(" dy: T" )
33+ .Input(" dy_: T" )
34+ .Input(" x: T" )
35+ .Output(" output: T" )
36+ .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
37+ c->set_output (0 , c->input (2 ));
38+ return Status::OK ();
39+ });
40+
// CPU functor computing the tanh approximation of GELU elementwise:
//   gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// `d` is unused on CPU; it keeps the signature uniform across device
// specializations. `size` is the element count of `in` and `out`.
template <typename Device, typename T>
struct GeluFunctor {
  void operator()(const Device& d, const T* in, T* out, int const size) {
    // Keep constants in T so float inputs are not silently promoted to
    // double on every operation (was the case with raw double literals).
    constexpr T kSqrt2OverPi = T(0.7978845608028654);  // sqrt(2/pi)
    constexpr T kCoeff = T(0.044715);
#pragma omp parallel for
    for (int ii = 0; ii < size; ii++) {
      const T x = in[ii];
      out[ii] = x * T(0.5) *
                (T(1) + std::tanh(kSqrt2OverPi * (x + kCoeff * x * x * x)));
    }
  }
};
50+
// CPU functor for the first derivative of the tanh-approximation GELU.
// With u = sqrt(2/pi) * (x + 0.044715 * x^3) and t = tanh(u):
//   d/dx gelu(x) = 0.5*(1 + t) + 0.5*x*(1 - t^2)*sqrt(2/pi)*(1 + 3*0.044715*x^2)
// (0.134145 == 3 * 0.044715.) The result is scaled by the incoming
// gradient `dy`. `d` is unused on CPU.
template <typename Device, typename T>
struct GeluGradFunctor {
  void operator()(const Device& d, const T* dy, const T* in, T* out, int const size) {
    // T-typed constants avoid float->double promotion in the hot loop.
    constexpr T kSqrt2OverPi = T(0.7978845608028654);  // sqrt(2/pi)
    constexpr T kCoeff = T(0.044715);
    constexpr T kCoeff3 = T(0.134145);  // 3 * kCoeff
#pragma omp parallel for
    for (int ii = 0; ii < size; ii++) {
      const T x = in[ii];
      const T t = std::tanh(kSqrt2OverPi * (x + kCoeff * x * x * x));
      out[ii] = dy[ii] * (T(0.5) * kSqrt2OverPi * x * (T(1) - t * t) *
                              (kCoeff3 * x * x + T(1)) +
                          T(0.5) * t + T(0.5));
    }
  }
};
61+
// CPU functor for the second derivative of the tanh-approximation GELU.
// With u = sqrt(2/pi)*(x + 0.044715*x^3), t = tanh(u), and
// dtanh = d(tanh(u))/dx = sqrt(2/pi) * (1 - t^2) * (0.134145*x^2 + 1):
//   d2/dx2 gelu(x) = 0.134145*sqrt(2/pi)*x^2*(1 - t^2)
//                  - sqrt(2/pi)*x*dtanh*(0.134145*x^2 + 1)*t
//                  + dtanh
// (verified by differentiating the GeluGrad expression). The result is
// scaled by both incoming gradients `dy` and `dy_`. `d` is unused on CPU.
template <typename Device, typename T>
struct GeluGradGradFunctor {
  void operator()(const Device& d, const T* dy, const T* dy_, const T* in, T* out, int const size) {
    // T-typed constants avoid float->double promotion in the hot loop.
    constexpr T kSqrt2OverPi = T(0.7978845608028654);  // sqrt(2/pi)
    constexpr T kCoeff = T(0.044715);
    constexpr T kCoeff3 = T(0.134145);  // 3 * kCoeff
#pragma omp parallel for
    for (int ii = 0; ii < size; ii++) {
      const T x = in[ii];
      const T t = std::tanh(kSqrt2OverPi * (x + kCoeff * x * x * x));
      const T sech2 = T(1) - t * t;                     // sech^2(u)
      const T poly = kCoeff3 * x * x + T(1);
      const T dtanh = kSqrt2OverPi * sech2 * poly;      // d(tanh(u))/dx
      out[ii] = dy[ii] * dy_[ii] *
                (kCoeff3 * kSqrt2OverPi * x * x * sech2 -
                 kSqrt2OverPi * x * dtanh * poly * t + dtanh);
    }
  }
};
74+
75+ // OpKernel definition.
76+ // template parameter <T> is the datatype of the tensors.
77+ template <typename Device, typename T>
78+ class GeluOp : public OpKernel {
79+ public :
80+ explicit GeluOp (OpKernelConstruction* context) : OpKernel(context) {}
81+
82+ void Compute (OpKernelContext* context) override {
83+ // Grab the input tensor
84+ const Tensor& x = context->input (0 );
85+
86+ Tensor * output = NULL ;
87+ int context_output_index = 0 ;
88+ OP_REQUIRES_OK (context, context->allocate_output (context_output_index++,
89+ x.shape (),
90+ &output));
91+
92+ GeluFunctor<Device, T>()(
93+ context->eigen_device <Device>(),
94+ x.flat <T>().data (),
95+ output->flat <T>().data (),
96+ static_cast <int >(output->NumElements ())
97+ );
98+ // GeluLauncher(x.flat<T>().data(), output->flat<T>().data(), static_cast<int>(output->NumElements()));
99+ }
100+ };
101+
102+ // OpKernel definition.
103+ // template parameter <T> is the datatype of the tensors.
104+ template <typename Device, typename T>
105+ class GeluGradOp : public OpKernel {
106+ public :
107+ explicit GeluGradOp (OpKernelConstruction* context) : OpKernel(context) {}
108+
109+ void Compute (OpKernelContext* context) override {
110+ // Grab the input tensor
111+ const Tensor& dy = context->input (0 );
112+ const Tensor& x = context->input (1 );
113+
114+ Tensor * output = NULL ;
115+ int context_output_index = 0 ;
116+ OP_REQUIRES_OK (context, context->allocate_output (context_output_index++,
117+ x.shape (),
118+ &output));
119+
120+ GeluGradFunctor<Device, T>()(
121+ context->eigen_device <Device>(),
122+ dy.flat <T>().data (),
123+ x.flat <T>().data (),
124+ output->flat <T>().data (),
125+ static_cast <int >(output->NumElements ())
126+ );
127+ // GeluGradLauncher(dy.flat<T>().data(), x.flat<T>().data(), output->flat<T>().data(), static_cast<int>(output->NumElements()));
128+ }
129+ };
130+
131+ // OpKernel definition.
132+ // template parameter <T> is the datatype of the tensors.
133+ template <typename Device, typename T>
134+ class GeluGradGradOp : public OpKernel {
135+ public :
136+ explicit GeluGradGradOp (OpKernelConstruction* context) : OpKernel(context) {}
137+
138+ void Compute (OpKernelContext* context) override {
139+ // Grab the input tensor
140+ const Tensor& dy = context->input (0 );
141+ const Tensor& dy_ = context->input (1 );
142+ const Tensor& x = context->input (2 );
143+
144+ Tensor * output = NULL ;
145+ int context_output_index = 0 ;
146+ OP_REQUIRES_OK (context, context->allocate_output (context_output_index++,
147+ x.shape (),
148+ &output));
149+
150+ GeluGradGradFunctor<Device, T>()(
151+ context->eigen_device <Device>(),
152+ dy.flat <T>().data (),
153+ dy_.flat <T>().data (),
154+ x.flat <T>().data (),
155+ output->flat <T>().data (),
156+ static_cast <int >(output->NumElements ())
157+ );
158+ // GeluGradGradLauncher(dy.flat<T>().data(), x.flat<T>().data(), output->flat<T>().data(), static_cast<int>(output->NumElements()));
159+ }
160+ };
161+
162+ #define REGISTER_CPU (T ) \
163+ /* Declare explicit instantiations in kernel_example.cu.cc. */ \
164+ REGISTER_KERNEL_BUILDER ( \
165+ Name (" Gelu" ).Device(DEVICE_CPU).TypeConstraint<T>(" T" ), \
166+ GeluOp<CPUDevice, T>); \
167+ /* Declare explicit instantiations in kernel_example.cu.cc. */ \
168+ REGISTER_KERNEL_BUILDER ( \
169+ Name (" GeluGrad" ).Device(DEVICE_CPU).TypeConstraint<T>(" T" ), \
170+ GeluGradOp<CPUDevice, T>); \
171+ /* Declare explicit instantiations in kernel_example.cu.cc. */ \
172+ REGISTER_KERNEL_BUILDER ( \
173+ Name (" GeluGradGrad" ).Device(DEVICE_CPU).TypeConstraint<T>(" T" ), \
174+ GeluGradGradOp<CPUDevice, T>);
175+ REGISTER_CPU (float );
176+ REGISTER_CPU (double );