11#include " custom_op.h"
22#include " ComputeDescriptor.h"
33#include " neighbor_list.h"
4+ #include " device.h"
5+
6+ #define GGELU 0.044715
47
58REGISTER_OP (" UnaggregatedDyDxS" )
69 .Attr(" T: {float, double} = DT_DOUBLE" )
710 .Input(" y: T" )
8- .Input(" w: T" )
11+ .Input(" w: T" )
12+ .Input(" xbar: T" )
13+ .Input(" functype: int32" )
914 .Output(" dy_dx: T" );
1015
1116REGISTER_OP (" UnaggregatedDyDx" )
1217 .Attr(" T: {float, double} = DT_DOUBLE" )
1318 .Input(" z: T" )
1419 .Input(" w: T" )
15- .Input(" dy_dx: T" )
20+ .Input(" dy_dx: T" )
21+ .Input(" ybar: T" )
22+ .Input(" functype: int32" )
1623 .Output(" dz_dx: T" );
1724
1825REGISTER_OP (" UnaggregatedDy2DxS" )
1926 .Attr(" T: {float, double} = DT_DOUBLE" )
2027 .Input(" y: T" )
2128 .Input(" dy: T" )
22- .Input(" w: T" )
29+ .Input(" w: T" )
30+ .Input(" xbar: T" )
31+ .Input(" functype: int32" )
2332 .Output(" dy2_dx: T" );
2433
2534REGISTER_OP (" UnaggregatedDy2Dx" )
2635 .Attr(" T: {float, double} = DT_DOUBLE" )
2736 .Input(" z: T" )
28- .Input(" w: T" )
29- .Input(" dz_dx: T" )
37+ .Input(" w: T" )
3038 .Input(" dy_dx: T" )
3139 .Input(" dy2_dx: T" )
40+ .Input(" ybar: T" )
41+ .Input(" functype: int32" )
3242 .Output(" dz2_dx: T" );
/**
 * First derivative of the activation selected by `functype`, evaluated for a
 * single element.
 *
 * @param xbar     Input of the activation (the kernel notes "xbar=xw+b");
 *                 only the functype == 2 branch reads it.
 * @param y        Output of the activation; only the functype == 1 branch
 *                 reads it.
 * @param functype 1 -> tanh, expressed through the output as 1 - y^2.
 *                 2 -> tanh-approximated GELU,
 *                      0.5 * x * (1 + tanh(SQRT_2_PI * (x + GGELU * x^3))).
 * @return The derivative, or -1 for any other functype. The -1 is a sentinel,
 *         not a derivative; the visible callers use the result unchecked, so
 *         functype should be validated before this is reached.
 *
 * GGELU is the 0.044715 constant defined at the top of this file. SQRT_2_PI is
 * not defined in the visible part of the file (presumably it comes from one of
 * the included headers - confirm). Both macros are double literals, so for
 * FPTYPE == float the GELU branch is evaluated in double and narrowed on
 * return.
 */
template <typename FPTYPE>
FPTYPE grad(const FPTYPE xbar, const FPTYPE y, const int functype)
{
    switch (functype) {
        case 1:
            return (1 - y * y);
        case 2: {
            // var = tanh(u) with u = SQRT_2_PI * (x + GGELU * x^3)
            const FPTYPE var = tanh(SQRT_2_PI * (xbar + GGELU * xbar * xbar * xbar));
            // 0.5 * x * (1 - tanh^2 u) * du/dx + 0.5 * (1 + tanh u)
            return 0.5 * SQRT_2_PI * xbar * (1 - var * var) * (3 * GGELU * xbar * xbar + 1) + 0.5 * var + 0.5;
        }
        default:
            return -1;
    }
}
60+
/**
 * Second derivative of the activation selected by `functype`, evaluated for a
 * single element.
 *
 * @param xbar     Input of the activation; only the functype == 2 branch reads
 *                 it.
 * @param y        Output of the activation; only the functype == 1 branch
 *                 reads it.
 * @param functype 1 -> tanh, expressed through the output as -2 * y * (1 - y^2).
 *                 2 -> tanh-approximated GELU.
 * @return The second derivative, or -1 for any other functype. The -1 is a
 *         sentinel, not a derivative; the visible callers use the result
 *         unchecked.
 *
 * GGELU is the 0.044715 constant defined at the top of this file. SQRT_2_PI is
 * not defined in the visible part of the file (presumably it comes from one of
 * the included headers - confirm). Both macros are double literals, so for
 * FPTYPE == float the GELU branch is evaluated in double and narrowed on
 * return.
 */
template <typename FPTYPE>
FPTYPE grad_grad(const FPTYPE xbar, const FPTYPE y, const int functype)
{
    switch (functype) {
        case 1:
            return -2 * y * (1 - y * y);
        case 2: {
            // var1 = tanh(u) with u = SQRT_2_PI * (x + GGELU * x^3)
            const FPTYPE var1 = tanh(SQRT_2_PI * (xbar + GGELU * xbar * xbar * xbar));
            // var2 = (1 - tanh^2 u) * du/dx
            const FPTYPE var2 = SQRT_2_PI * (1 - var1 * var1) * (3 * GGELU * xbar * xbar + 1);
            return 3 * GGELU * SQRT_2_PI * xbar * xbar * (1 - var1 * var1) - SQRT_2_PI * xbar * var2 * (3 * GGELU * xbar * xbar + 1) * var1 + var2;
        }
        default:
            return -1;
    }
}
78+
79+
3380
3481template <typename FPTYPE>
3582struct UnaggregatedDyDxSFunctor {
36- void operator ()(const CPUDevice& d, const FPTYPE * y, const FPTYPE * w, const int length, const int width, FPTYPE * dy_dx) {
83+ void operator ()(const CPUDevice& d, const FPTYPE * y, const FPTYPE * w, const FPTYPE* xbar, const int length, const int width, FPTYPE * dy_dx, const int functype ) {
3784 #pragma omp parallel for
3885 for (int ii = 0 ; ii < length; ii++) {
3986 for (int jj = 0 ; jj < width; jj++) {
40- dy_dx[ii * width + jj] = ( 1 - y [ii * width + jj] * y[ii * width + jj]) * w[jj];
87+ dy_dx[ii * width + jj] = grad (xbar [ii * width + jj], y[ii * width + jj],functype)* w[jj];
4188 }
4289 }
4390 }
@@ -53,12 +100,13 @@ struct UnaggregatedDyDxSFunctor {
53100// calculate the gradient for all variables!
54101template <typename FPTYPE>
55102struct UnaggregatedDyDxFunctor {
56- void operator ()(const CPUDevice& d, const FPTYPE * z, const FPTYPE * w, const FPTYPE * dy_dx, const int length, const int width, const int size, FPTYPE * dz_dx) {
103+ void operator ()(const CPUDevice& d, const FPTYPE * z, const FPTYPE * w, const FPTYPE * dy_dx, const FPTYPE * ybar, const int length, const int width, const int size, FPTYPE * dz_dx, const int functype) {
104+ // width=2*size
57105 #pragma omp parallel for
58106 for (int kk = 0 ; kk < length; kk++) {
59107 for (int ii = 0 ; ii < width; ii++) {
60108 // FPTYPE dz_drou = 1 - (z[kk * width + ii] - y[kk * size + ii % size]) * (z[kk * width + ii] - y[kk * size + ii % size]);
61- FPTYPE dz_drou = 1 - z [kk * width + ii] * z[kk * width + ii];
109+ FPTYPE dz_drou = grad (ybar [kk* width+ ii], z[kk * width + ii],functype) ;
62110 FPTYPE accumulator = 0.0 ;
63111 for (int jj = 0 ; jj < size; jj++) {
64112 accumulator += w[jj * width + ii] * dy_dx[kk * size + jj];
@@ -80,11 +128,11 @@ struct UnaggregatedDyDxFunctor {
80128
81129template <typename FPTYPE>
82130struct UnaggregatedDy2DxSFunctor {
83- void operator ()(const CPUDevice& d, const FPTYPE * y, const FPTYPE * dy, const FPTYPE * w, const int length, const int width, FPTYPE * dy2_dx) {
131+ void operator ()(const CPUDevice& d, const FPTYPE * y, const FPTYPE * dy, const FPTYPE * w, const FPTYPE* xbar, const int length, const int width, FPTYPE * dy2_dx, const int functype ) {
84132 #pragma omp parallel for
85133 for (int ii = 0 ; ii < length; ii++) {
86134 for (int jj = 0 ; jj < width; jj++) {
87- dy2_dx[ii * width + jj] = - 2 * w[jj] * y[ ii * width + jj] * dy [ii * width + jj];
135+ dy2_dx[ii * width + jj] = grad_grad (xbar[ ii * width + jj],y [ii * width + jj],functype)*w[jj]*w[ jj];
88136 }
89137 }
90138 }
@@ -100,12 +148,12 @@ struct UnaggregatedDy2DxSFunctor {
100148// calculate the gradient for all variables!
101149template <typename FPTYPE>
102150struct UnaggregatedDy2DxFunctor {
103- void operator ()(const CPUDevice& d, const FPTYPE * z, const FPTYPE * w, const FPTYPE * dz_dx , const FPTYPE * dy_dx , const FPTYPE * dy2_dx , const int length, const int width, const int size, FPTYPE * dz2_dx) {
151+ void operator ()(const CPUDevice& d, const FPTYPE * z, const FPTYPE * w, const FPTYPE * dy_dx , const FPTYPE * dy2_dx , const FPTYPE * ybar , const int length, const int width, const int size, FPTYPE * dz2_dx, const int functype ) {
104152 #pragma omp parallel for
105153 for (int kk = 0 ; kk < length; kk++) {
106154 for (int ii = 0 ; ii < width; ii++) {
107155 // FPTYPE dz_drou = 1 - (z[kk * width + ii] - y[kk * size + ii % size]) * (z[kk * width + ii] - y[kk * size + ii % size]);
108- FPTYPE dz_drou = 1 - z [kk * width + ii] * z[kk * width + ii];
156+ FPTYPE dz_drou = grad (ybar [kk* width+ ii], z[kk * width + ii],functype) ;
109157 FPTYPE accumulator = 0.0 ;
110158 for (int jj = 0 ; jj < size; jj++) {
111159 accumulator += w[jj * width + ii] * dy2_dx[kk * size + jj];
@@ -115,7 +163,7 @@ struct UnaggregatedDy2DxFunctor {
115163 for (int jj = 0 ; jj < size; jj++) {
116164 accumulator += w[jj * width + ii] * dy_dx[kk * size + jj];
117165 }
118- dz_drou -= 2 * z [kk * width + ii] * (dz_dx [kk * width + ii] - dy_dx[kk * size + ii % size]) * accumulator;
166+ dz_drou += grad_grad (ybar [kk * width + ii], z [kk * width + ii],functype) * accumulator * accumulator;
119167 dz_drou += dy2_dx[kk * size + ii % size];
120168 dz2_dx[kk * width + ii] = dz_drou;
121169 }
@@ -141,13 +189,18 @@ class UnaggregatedDyDxSOp : public OpKernel {
141189
142190 void _Compute (OpKernelContext* context) {
143191 // Grab the input tensor
192+ // xbar=xw+b
144193 int context_input_index = 0 ;
145194 const Tensor& y = context->input (context_input_index++);
146195 const Tensor& w = context->input (context_input_index++);
196+ const Tensor& xbar = context->input (context_input_index++);
197+ const Tensor& functype = context->input (context_input_index++);
147198
148199 // set size of the sample
149- OP_REQUIRES (context, (y.shape ().dims () == 2 ), errors::InvalidArgument (" Dim of table should be 1 " ));
200+ OP_REQUIRES (context, (y.shape ().dims () == 2 ), errors::InvalidArgument (" Dim of input should be 2 " ));
150201 OP_REQUIRES (context, (w.shape ().dims () == 2 ), errors::InvalidArgument (" Dim of input should be 2" ));
202+ OP_REQUIRES (context, (xbar.shape ().dims () == 2 ), errors::InvalidArgument (" Dim of input should be 2" ));
203+ // check functype
151204
152205 int context_output_index = 0 ;
153206 Tensor* dy_dx = NULL ;
@@ -159,9 +212,11 @@ class UnaggregatedDyDxSOp : public OpKernel {
159212 context->eigen_device <Device>(), // define actually graph execution device
160213 y.flat <FPTYPE>().data (),
161214 w.flat <FPTYPE>().data (),
215+ xbar.flat <FPTYPE>().data (),
162216 y.shape ().dim_size (0 ),
163217 y.shape ().dim_size (1 ),
164- dy_dx->flat <FPTYPE>().data ()
218+ dy_dx->flat <FPTYPE>().data (),
219+ functype.flat <int32>()(0 )
165220 );
166221 }
167222private:
@@ -182,14 +237,17 @@ class UnaggregatedDy2DxSOp : public OpKernel {
182237 const Tensor& y = context->input (context_input_index++);
183238 const Tensor& dy = context->input (context_input_index++);
184239 const Tensor& w = context->input (context_input_index++);
240+ const Tensor& xbar = context->input (context_input_index++);
241+ const Tensor& functype = context->input (context_input_index++);
185242
186243 // set size of the sample
187244 OP_REQUIRES (context, (y.shape ().dims () == 2 ), errors::InvalidArgument (" Dim of input should be 2" ));
188245 OP_REQUIRES (context, (dy.shape ().dims () == 2 ), errors::InvalidArgument (" Dim of input should be 2" ));
189246 OP_REQUIRES (context, (w.shape ().dims () == 2 ), errors::InvalidArgument (" Dim of input should be 2" ));
247+ OP_REQUIRES (context, (xbar.shape ().dims () == 2 ), errors::InvalidArgument (" Dim of input should be 2" ));
190248
191249 int context_output_index = 0 ;
192- Tensor* dy2_dx = NULL ;
250+ Tensor* dy2_dx = NULL ;
193251 OP_REQUIRES_OK (context, context->allocate_output (context_output_index++,
194252 y.shape (),
195253 &dy2_dx));
@@ -199,9 +257,11 @@ class UnaggregatedDy2DxSOp : public OpKernel {
199257 y.flat <FPTYPE>().data (),
200258 dy.flat <FPTYPE>().data (),
201259 w.flat <FPTYPE>().data (),
260+ xbar.flat <FPTYPE>().data (),
202261 y.shape ().dim_size (0 ),
203262 y.shape ().dim_size (1 ),
204- dy2_dx->flat <FPTYPE>().data ()
263+ dy2_dx->flat <FPTYPE>().data (),
264+ functype.flat <int32>()(0 )
205265 );
206266 }
207267private:
@@ -222,11 +282,14 @@ class UnaggregatedDyDxOp : public OpKernel {
222282 const Tensor& z = context->input (context_input_index++);
223283 const Tensor& w = context->input (context_input_index++);
224284 const Tensor& dy_dx = context->input (context_input_index++);
285+ const Tensor& ybar = context->input (context_input_index++);
286+ const Tensor& functype = context->input (context_input_index++);
225287
226288 // set size of the sample
227- OP_REQUIRES (context, (z.shape ().dims () == 2 ), errors::InvalidArgument (" Dim of table should be 1 " ));
289+ OP_REQUIRES (context, (z.shape ().dims () == 2 ), errors::InvalidArgument (" Dim of input should be 2 " ));
228290 OP_REQUIRES (context, (w.shape ().dims () == 2 ), errors::InvalidArgument (" Dim of input should be 2" ));
229291 OP_REQUIRES (context, (dy_dx.shape ().dims () == 2 ), errors::InvalidArgument (" Dim of input should be 2" ));
292+ OP_REQUIRES (context, (ybar.shape ().dims () == 2 ), errors::InvalidArgument (" Dim of input should be 2" ));
230293
231294 int context_output_index = 0 ;
232295 Tensor* dz_dx = NULL ;
@@ -239,10 +302,12 @@ class UnaggregatedDyDxOp : public OpKernel {
239302 z.flat <FPTYPE>().data (),
240303 w.flat <FPTYPE>().data (),
241304 dy_dx.flat <FPTYPE>().data (),
305+ ybar.flat <FPTYPE>().data (),
242306 z.shape ().dim_size (0 ),
243- z.shape ().dim_size (1 ),
244- w.shape ().dim_size (0 ),
245- dz_dx->flat <FPTYPE>().data ()
307+ z.shape ().dim_size (1 ), // N1
308+ w.shape ().dim_size (0 ), // N0 , N1=2N0
309+ dz_dx->flat <FPTYPE>().data (),
310+ functype.flat <int32>()(0 )
246311 );
247312 }
248313private:
@@ -262,16 +327,17 @@ class UnaggregatedDy2DxOp : public OpKernel {
262327 int context_input_index = 0 ;
263328 const Tensor& z = context->input (context_input_index++);
264329 const Tensor& w = context->input (context_input_index++);
265- const Tensor& dz_dx = context->input (context_input_index++);
266330 const Tensor& dy_dx = context->input (context_input_index++);
267331 const Tensor& dy2_dx = context->input (context_input_index++);
332+ const Tensor& ybar = context->input (context_input_index++);
333+ const Tensor& functype = context->input (context_input_index++);
268334
269335 // set size of the sample
270336 OP_REQUIRES (context, (z.shape ().dims () == 2 ), errors::InvalidArgument (" Dim of input should be 2" ));
271337 OP_REQUIRES (context, (w.shape ().dims () == 2 ), errors::InvalidArgument (" Dim of input should be 2" ));
272- OP_REQUIRES (context, (dz_dx.shape ().dims () == 2 ), errors::InvalidArgument (" Dim of input should be 2" ));
273338 OP_REQUIRES (context, (dy_dx.shape ().dims () == 2 ), errors::InvalidArgument (" Dim of input should be 2" ));
274339 OP_REQUIRES (context, (dy2_dx.shape ().dims () == 2 ), errors::InvalidArgument (" Dim of input should be 2" ));
340+ OP_REQUIRES (context, (ybar.shape ().dims () == 2 ), errors::InvalidArgument (" Dim of input should be 2" ));
275341
276342 int context_output_index = 0 ;
277343 Tensor* dz2_dx = NULL ;
@@ -283,13 +349,14 @@ class UnaggregatedDy2DxOp : public OpKernel {
283349 context->eigen_device <Device>(), // define actually graph execution device
284350 z.flat <FPTYPE>().data (),
285351 w.flat <FPTYPE>().data (),
286- dz_dx.flat <FPTYPE>().data (),
287352 dy_dx.flat <FPTYPE>().data (),
288353 dy2_dx.flat <FPTYPE>().data (),
354+ ybar.flat <FPTYPE>().data (),
289355 z.shape ().dim_size (0 ),
290356 z.shape ().dim_size (1 ),
291357 w.shape ().dim_size (0 ),
292- dz2_dx->flat <FPTYPE>().data ()
358+ dz2_dx->flat <FPTYPE>().data (),
359+ functype.flat <int32>()(0 )
293360 );
294361 }
295362private:
0 commit comments