Commit 5ab5fa1 (PR #1020)

Decouple the activation function's type from the model compression process in SE_A; both tanh and gelu are now available.

Authored by liangadamHLA, co-authored by HLA <[email protected]>.
Parent: 8bbe565

Squashed commit messages:
* decouple activation function's type from model compression's process in SE_A, now tanh & gelu is both available.
* modified code and passed unittest
* Format Document
* Format revert
* format change
* Format change

File tree: 4 files changed (+185, -39 lines)

deepmd/descriptor/se_a.py (3 additions, 1 deletion)

@@ -126,6 +126,7 @@ def __init__ (self,
         self.uniform_seed = uniform_seed
         self.seed_shift = embedding_net_rand_seed_shift(self.filter_neuron)
         self.trainable = trainable
+        self.compress_activation_fn = get_activation_func(activation_function)
         self.filter_activation_fn = get_activation_func(activation_function)
         self.filter_precision = get_precision(precision)
         self.filter_np_precision = get_np_precision(precision)

@@ -316,7 +317,8 @@ def enable_compression(self,
             The overflow check frequency
         """
         self.compress = True
-        self.table = DPTabulate(model_file, self.type_one_side, self.exclude_types)
+        self.table = DPTabulate(
+            model_file, self.type_one_side, self.exclude_types, self.compress_activation_fn)
         self.table_config = [table_extrapolate, table_stride_1, table_stride_2, check_frequency]
         self.lower, self.upper \
             = self.table.build(min_nbor_dist,
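The net effect of this hunk is that `enable_compression` now forwards the descriptor's activation function to `DPTabulate` instead of letting the table builder hard-code tanh. Below is a minimal sketch of the resulting flow; it assumes `get_activation_func` and `ACTIVATION_FN_DICT` both come from `deepmd.common`, as the imports in this commit suggest, and it is illustrative rather than a drop-in piece of the library.

```python
# Hypothetical, minimal sketch of the new data flow (not library code as-is).
from deepmd.common import ACTIVATION_FN_DICT, get_activation_func  # assumed location

activation_function = "gelu"                 # descriptor setting in the training input
compress_activation_fn = get_activation_func(activation_function)

# DPTabulate.__init__ translates the callable into the integer "functype"
# handed to the custom TensorFlow ops (see deepmd/utils/tabulate.py below).
if compress_activation_fn == ACTIVATION_FN_DICT["tanh"]:
    functype = 1
elif compress_activation_fn == ACTIVATION_FN_DICT["gelu"]:
    functype = 2
else:
    raise RuntimeError("Unknown activation function type!")

print(functype)  # 2 for a gelu-trained model
```

Before this commit the table builder always evaluated `tf.nn.tanh`, so a model trained with gelu could not be tabulated consistently; now the same callable drives both the Python reference evaluation and the C++ gradient ops.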

deepmd/utils/tabulate.py (37 additions, 12 deletions)

@@ -2,9 +2,11 @@
 import math
 import logging
 import numpy as np
+from typing import Callable
 from typing import Tuple, List
 from deepmd.env import tf
 from deepmd.env import op_module
+from deepmd.common import ACTIVATION_FN_DICT
 from deepmd.utils.sess import run_sess
 from deepmd.utils.graph import get_tensor_by_name_from_graph, load_graph_def
 from deepmd.utils.graph import get_embedding_net_nodes_from_graph_def

@@ -30,11 +32,14 @@ class DPTabulate():
     exclude_types : List[List[int]]
             The excluded pairs of types which have no interaction with each other.
             For example, `[[0, 1]]` means no interaction between type 0 and type 1.
+    activation_function
+            The activation function in the embedding net. Supported options are {"tanh","gelu"} in common.ACTIVATION_FN_DICT.
     """
     def __init__(self,
                  model_file : str,
                  type_one_side : bool = False,
-                 exclude_types : List[List[int]] = []) -> None:
+                 exclude_types : List[List[int]] = [],
+                 activation_fn : Callable[[tf.Tensor], tf.Tensor] = tf.nn.tanh) -> None:
         """
         Constructor
         """

@@ -44,6 +49,15 @@ def __init__(self,
         self.exclude_types = exclude_types
         if self.type_one_side and len(self.exclude_types) != 0:
             raise RunTimeError('"type_one_side" is not compatible with "exclude_types"')
+
+        # functype
+        if activation_fn == ACTIVATION_FN_DICT["tanh"]:
+            self.functype = 1
+        elif activation_fn == ACTIVATION_FN_DICT["gelu"]:
+            self.functype = 2
+        else:
+            raise RunTimeError("Unknown actication function type!")
+        self.activation_fn = activation_fn

         self.graph, self.graph_def = load_graph_def(self.model_file)
         self.sess = tf.Session(graph = self.graph)

@@ -199,26 +213,37 @@ def _make_data(self, xx, idx):
         xx = tf.reshape(xx, [xx.size, -1])
         for layer in range(self.layer_size):
             if layer == 0:
-                yy = self._layer_0(xx, self.matrix["layer_" + str(layer + 1)][idx], self.bias["layer_" + str(layer + 1)][idx])
-                dy = op_module.unaggregated_dy_dx_s(yy, self.matrix["layer_" + str(layer + 1)][idx])
-                dy2 = op_module.unaggregated_dy2_dx_s(yy, dy, self.matrix["layer_" + str(layer + 1)][idx])
+                xbar = tf.matmul(
+                    xx, self.matrix["layer_" + str(layer + 1)][idx]) + self.bias["layer_" + str(layer + 1)][idx]
+                yy = self._layer_0(
+                    xx, self.matrix["layer_" + str(layer + 1)][idx], self.bias["layer_" + str(layer + 1)][idx])
+                dy = op_module.unaggregated_dy_dx_s(
+                    yy, self.matrix["layer_" + str(layer + 1)][idx], xbar, tf.constant(self.functype))
+                dy2 = op_module.unaggregated_dy2_dx_s(
+                    yy, dy, self.matrix["layer_" + str(layer + 1)][idx], xbar, tf.constant(self.functype))
             else:
-                tt, yy = self._layer_1(yy, self.matrix["layer_" + str(layer + 1)][idx], self.bias["layer_" + str(layer + 1)][idx])
-                dz = op_module.unaggregated_dy_dx(yy - tt, self.matrix["layer_" + str(layer + 1)][idx], dy)
-                dy2 = op_module.unaggregated_dy2_dx(yy - tt, self.matrix["layer_" + str(layer + 1)][idx], dz, dy, dy2)
+                ybar = tf.matmul(
+                    yy, self.matrix["layer_" + str(layer + 1)][idx]) + self.bias["layer_" + str(layer + 1)][idx]
+                tt, zz = self._layer_1(
+                    yy, self.matrix["layer_" + str(layer + 1)][idx], self.bias["layer_" + str(layer + 1)][idx])
+                dz = op_module.unaggregated_dy_dx(
+                    zz - tt, self.matrix["layer_" + str(layer + 1)][idx], dy, ybar, tf.constant(self.functype))
+                dy2 = op_module.unaggregated_dy2_dx(
+                    zz - tt, self.matrix["layer_" + str(layer + 1)][idx], dy, dy2, ybar, tf.constant(self.functype))
                 dy = dz
-
-        vv = yy.eval()
+                yy = zz
+
+        vv = zz.eval()
         dd = dy.eval()
         d2 = dy2.eval()
         return vv, dd, d2

     def _layer_0(self, x, w, b):
-        return tf.nn.tanh(tf.matmul(x, w) + b)
+        return self.activation_fn(tf.matmul(x, w) + b)

     def _layer_1(self, x, w, b):
-        t = tf.concat([x, x], axis = 1)
-        return t, tf.nn.tanh(tf.matmul(x, w) + b) + t
+        t = tf.concat([x, x], axis=1)
+        return t, self.activation_fn(tf.matmul(x, w) + b) + t

     def _save_data(self):
         for ii in range(self.ntypes * self.ntypes):
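Why `_make_data` now also builds the pre-activation `xbar = x.W + b` (and `ybar` for deeper layers) and passes it to the ops: the tanh derivative can be reconstructed from the layer output alone (`1 - y**2`), which is what the old hard-coded path relied on, whereas the gelu derivative depends on the pre-activation itself. The NumPy sketch below illustrates this; the shapes and the value of SQRT_2_PI (taken as sqrt(2/pi)) are assumptions for the example, not values read from the library.

```python
# Illustrative NumPy sketch only; shapes and constants are assumptions for the example.
import numpy as np

SQRT_2_PI = np.sqrt(2.0 / np.pi)   # assumed to match SQRT_2_PI used by the C++ ops
GGELU = 0.044715                   # same constant as in unaggregated_grad.cc

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 1))    # a slice of the tabulated input grid (hypothetical)
w = rng.standard_normal((1, 8))    # layer-0 weights (hypothetical shape)
b = rng.standard_normal((1, 8))

xbar = x @ w + b                   # the pre-activation that _make_data now computes

# tanh: derivative recoverable from the output alone, as in the old hard-coded path
y_tanh = np.tanh(xbar)
dy_tanh = (1.0 - y_tanh**2) * w    # dy/dx per output column, needs only y and w

# gelu (tanh approximation): the derivative needs xbar itself, hence the new op input
u = SQRT_2_PI * (xbar + GGELU * xbar**3)
y_gelu = 0.5 * xbar * (1.0 + np.tanh(u))
dy_gelu = (0.5 * (1.0 + np.tanh(u))
           + 0.5 * xbar * (1.0 - np.tanh(u)**2) * SQRT_2_PI
           * (1.0 + 3.0 * GGELU * xbar**2)) * w
# dy_gelu cannot be written as a function of y_gelu alone, so the op takes xbar too.
```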

source/op/unaggregated_grad.cc (93 additions, 26 deletions)

@@ -1,43 +1,90 @@
 #include "custom_op.h"
 #include "ComputeDescriptor.h"
 #include "neighbor_list.h"
+#include "device.h"
+
+#define GGELU 0.044715

 REGISTER_OP("UnaggregatedDyDxS")
     .Attr("T: {float, double} = DT_DOUBLE")
     .Input("y: T")
-    .Input("w: T")
+    .Input("w: T")
+    .Input("xbar: T")
+    .Input("functype: int32")
     .Output("dy_dx: T");

 REGISTER_OP("UnaggregatedDyDx")
     .Attr("T: {float, double} = DT_DOUBLE")
     .Input("z: T")
     .Input("w: T")
-    .Input("dy_dx: T")
+    .Input("dy_dx: T")
+    .Input("ybar: T")
+    .Input("functype: int32")
     .Output("dz_dx: T");

 REGISTER_OP("UnaggregatedDy2DxS")
     .Attr("T: {float, double} = DT_DOUBLE")
     .Input("y: T")
     .Input("dy: T")
-    .Input("w: T")
+    .Input("w: T")
+    .Input("xbar: T")
+    .Input("functype: int32")
     .Output("dy2_dx: T");

 REGISTER_OP("UnaggregatedDy2Dx")
     .Attr("T: {float, double} = DT_DOUBLE")
     .Input("z: T")
-    .Input("w: T")
-    .Input("dz_dx: T")
+    .Input("w: T")
     .Input("dy_dx: T")
     .Input("dy2_dx: T")
+    .Input("ybar: T")
+    .Input("functype: int32")
     .Output("dz2_dx: T");
+
+template <typename FPTYPE>
+FPTYPE grad(const FPTYPE xbar, const FPTYPE y, const int functype) //functype=tanh, gelu, ..
+{
+    switch (functype)
+    {
+    case 1:
+        return (1 - y * y);
+    case 2:
+    {
+        const FPTYPE var = tanh(SQRT_2_PI * (xbar + GGELU * xbar * xbar * xbar));
+        return 0.5 * SQRT_2_PI * xbar * (1 - var * var) * (3 * GGELU * xbar * xbar + 1) + 0.5 * var + 0.5;
+    }
+    default:
+        return -1;
+    }
+}
+
+template <typename FPTYPE>
+FPTYPE grad_grad(const FPTYPE xbar, const FPTYPE y, const int functype)
+{
+    switch (functype)
+    {
+    case 1:
+        return -2 * y * (1 - y * y);
+    case 2:
+    {
+        const FPTYPE var1 = tanh(SQRT_2_PI * (xbar + GGELU * xbar * xbar * xbar));
+        const FPTYPE var2 = SQRT_2_PI * (1 - var1 * var1) * (3 * GGELU * xbar * xbar + 1);
+        return 3 * GGELU * SQRT_2_PI * xbar * xbar * (1 - var1 * var1) - SQRT_2_PI * xbar * var2 * (3 * GGELU * xbar * xbar + 1) * var1 + var2;
+    }
+    default:
+        return -1;
+    }
+}
+

 template <typename FPTYPE>
 struct UnaggregatedDyDxSFunctor {
-    void operator()(const CPUDevice& d, const FPTYPE * y, const FPTYPE * w, const int length, const int width, FPTYPE * dy_dx) {
+    void operator()(const CPUDevice& d, const FPTYPE * y, const FPTYPE * w, const FPTYPE* xbar, const int length, const int width, FPTYPE * dy_dx, const int functype) {
 #pragma omp parallel for
         for (int ii = 0; ii < length; ii++) {
             for (int jj = 0; jj < width; jj++) {
-                dy_dx[ii * width + jj] = (1 - y[ii * width + jj] * y[ii * width + jj]) * w[jj];
+                dy_dx[ii * width + jj] = grad(xbar[ii * width + jj], y[ii * width + jj],functype)*w[jj];
             }
         }
     }

@@ -53,12 +100,13 @@ struct UnaggregatedDyDxSFunctor {
 // calculate the gradient for all variables!
 template <typename FPTYPE>
 struct UnaggregatedDyDxFunctor {
-    void operator()(const CPUDevice& d, const FPTYPE * z, const FPTYPE * w, const FPTYPE * dy_dx, const int length, const int width, const int size, FPTYPE * dz_dx) {
+    void operator()(const CPUDevice& d, const FPTYPE * z, const FPTYPE * w, const FPTYPE * dy_dx, const FPTYPE * ybar, const int length, const int width, const int size, FPTYPE * dz_dx, const int functype) {
+        //width=2*size
 #pragma omp parallel for
         for (int kk = 0; kk < length; kk++) {
             for (int ii = 0; ii < width; ii++) {
                 //FPTYPE dz_drou = 1 - (z[kk * width + ii] - y[kk * size + ii % size]) * (z[kk * width + ii] - y[kk * size + ii % size]);
-                FPTYPE dz_drou = 1 - z[kk * width + ii] * z[kk * width + ii];
+                FPTYPE dz_drou = grad(ybar[kk*width+ii], z[kk * width + ii],functype);
                 FPTYPE accumulator = 0.0;
                 for (int jj = 0; jj < size; jj++) {
                     accumulator += w[jj * width + ii] * dy_dx[kk * size + jj];

@@ -80,11 +128,11 @@ struct UnaggregatedDyDxFunctor {
 template <typename FPTYPE>
 struct UnaggregatedDy2DxSFunctor {
-    void operator()(const CPUDevice& d, const FPTYPE * y, const FPTYPE * dy, const FPTYPE * w, const int length, const int width, FPTYPE * dy2_dx) {
+    void operator()(const CPUDevice& d, const FPTYPE * y, const FPTYPE * dy, const FPTYPE * w, const FPTYPE* xbar, const int length, const int width, FPTYPE * dy2_dx, const int functype) {
 #pragma omp parallel for
         for (int ii = 0; ii < length; ii++) {
             for (int jj = 0; jj < width; jj++) {
-                dy2_dx[ii * width + jj] = -2 * w[jj] * y[ii * width + jj] * dy[ii * width + jj];
+                dy2_dx[ii * width + jj] = grad_grad(xbar[ii * width + jj],y[ii * width + jj],functype)*w[jj]*w[jj];
             }
         }
     }

@@ -100,12 +148,12 @@ struct UnaggregatedDy2DxSFunctor {
 // calculate the gradient for all variables!
 template <typename FPTYPE>
 struct UnaggregatedDy2DxFunctor {
-    void operator()(const CPUDevice& d, const FPTYPE * z, const FPTYPE * w, const FPTYPE * dz_dx, const FPTYPE * dy_dx, const FPTYPE * dy2_dx, const int length, const int width, const int size, FPTYPE * dz2_dx) {
+    void operator()(const CPUDevice& d, const FPTYPE * z, const FPTYPE * w, const FPTYPE * dy_dx, const FPTYPE * dy2_dx, const FPTYPE * ybar, const int length, const int width, const int size, FPTYPE * dz2_dx, const int functype) {
 #pragma omp parallel for
         for (int kk = 0; kk < length; kk++) {
             for (int ii = 0; ii < width; ii++) {
                 //FPTYPE dz_drou = 1 - (z[kk * width + ii] - y[kk * size + ii % size]) * (z[kk * width + ii] - y[kk * size + ii % size]);
-                FPTYPE dz_drou = 1 - z[kk * width + ii] * z[kk * width + ii];
+                FPTYPE dz_drou = grad(ybar[kk*width+ii], z[kk * width + ii],functype);
                 FPTYPE accumulator = 0.0;
                 for (int jj = 0; jj < size; jj++) {
                     accumulator += w[jj * width + ii] * dy2_dx[kk * size + jj];

@@ -115,7 +163,7 @@ struct UnaggregatedDy2DxFunctor {
                 for (int jj = 0; jj < size; jj++) {
                     accumulator += w[jj * width + ii] * dy_dx[kk * size + jj];
                 }
-                dz_drou -= 2 * z[kk * width + ii] * (dz_dx[kk * width + ii] - dy_dx[kk * size + ii % size]) * accumulator;
+                dz_drou += grad_grad(ybar[kk * width + ii], z[kk * width + ii],functype) * accumulator * accumulator;
                 dz_drou += dy2_dx[kk * size + ii % size];
                 dz2_dx[kk * width + ii] = dz_drou;
             }

@@ -141,13 +189,18 @@ class UnaggregatedDyDxSOp : public OpKernel {

     void _Compute(OpKernelContext* context) {
         // Grab the input tensor
+        //xbar=xw+b
         int context_input_index = 0;
         const Tensor& y = context->input(context_input_index++);
         const Tensor& w = context->input(context_input_index++);
+        const Tensor& xbar = context->input(context_input_index++);
+        const Tensor& functype = context->input(context_input_index++);

         // set size of the sample
-        OP_REQUIRES (context, (y.shape().dims() == 2), errors::InvalidArgument ("Dim of table should be 1"));
+        OP_REQUIRES (context, (y.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
         OP_REQUIRES (context, (w.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
+        OP_REQUIRES(context, (xbar.shape().dims() == 2), errors::InvalidArgument("Dim of input should be 2"));
+        //check functype

         int context_output_index = 0;
         Tensor* dy_dx = NULL;

@@ -159,9 +212,11 @@ class UnaggregatedDyDxSOp : public OpKernel {
             context->eigen_device<Device>(), // define actually graph execution device
             y.flat<FPTYPE>().data(),
             w.flat<FPTYPE>().data(),
+            xbar.flat<FPTYPE>().data(),
             y.shape().dim_size(0),
             y.shape().dim_size(1),
-            dy_dx->flat<FPTYPE>().data()
+            dy_dx->flat<FPTYPE>().data(),
+            functype.flat<int32>()(0)
         );
     }
 private:

@@ -182,14 +237,17 @@ class UnaggregatedDy2DxSOp : public OpKernel {
         const Tensor& y = context->input(context_input_index++);
         const Tensor& dy = context->input(context_input_index++);
         const Tensor& w = context->input(context_input_index++);
+        const Tensor& xbar = context->input(context_input_index++);
+        const Tensor& functype = context->input(context_input_index++);

         // set size of the sample
         OP_REQUIRES (context, (y.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
         OP_REQUIRES (context, (dy.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
         OP_REQUIRES (context, (w.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
+        OP_REQUIRES (context, (xbar.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));

         int context_output_index = 0;
-        Tensor* dy2_dx = NULL;
+        Tensor* dy2_dx = NULL;
         OP_REQUIRES_OK(context, context->allocate_output(context_output_index++,
             y.shape(),
             &dy2_dx));

@@ -199,9 +257,11 @@ class UnaggregatedDy2DxSOp : public OpKernel {
             y.flat<FPTYPE>().data(),
             dy.flat<FPTYPE>().data(),
             w.flat<FPTYPE>().data(),
+            xbar.flat<FPTYPE>().data(),
             y.shape().dim_size(0),
             y.shape().dim_size(1),
-            dy2_dx->flat<FPTYPE>().data()
+            dy2_dx->flat<FPTYPE>().data(),
+            functype.flat<int32>()(0)
         );
     }
 private:

@@ -222,11 +282,14 @@ class UnaggregatedDyDxOp : public OpKernel {
         const Tensor& z = context->input(context_input_index++);
         const Tensor& w = context->input(context_input_index++);
         const Tensor& dy_dx = context->input(context_input_index++);
+        const Tensor& ybar = context->input(context_input_index++);
+        const Tensor& functype = context->input(context_input_index++);

         // set size of the sample
-        OP_REQUIRES (context, (z.shape().dims() == 2), errors::InvalidArgument ("Dim of table should be 1"));
+        OP_REQUIRES (context, (z.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
         OP_REQUIRES (context, (w.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
         OP_REQUIRES (context, (dy_dx.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
+        OP_REQUIRES (context, (ybar.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));

         int context_output_index = 0;
         Tensor* dz_dx = NULL;

@@ -239,10 +302,12 @@ class UnaggregatedDyDxOp : public OpKernel {
             z.flat<FPTYPE>().data(),
             w.flat<FPTYPE>().data(),
             dy_dx.flat<FPTYPE>().data(),
+            ybar.flat<FPTYPE>().data(),
             z.shape().dim_size(0),
-            z.shape().dim_size(1),
-            w.shape().dim_size(0),
-            dz_dx->flat<FPTYPE>().data()
+            z.shape().dim_size(1), //N1
+            w.shape().dim_size(0), //N0 , N1=2N0
+            dz_dx->flat<FPTYPE>().data(),
+            functype.flat<int32>()(0)
         );
     }
 private:

@@ -262,16 +327,17 @@ class UnaggregatedDy2DxOp : public OpKernel {
         int context_input_index = 0;
         const Tensor& z = context->input(context_input_index++);
         const Tensor& w = context->input(context_input_index++);
-        const Tensor& dz_dx = context->input(context_input_index++);
         const Tensor& dy_dx = context->input(context_input_index++);
         const Tensor& dy2_dx = context->input(context_input_index++);
+        const Tensor& ybar = context->input(context_input_index++);
+        const Tensor& functype = context->input(context_input_index++);

         // set size of the sample
         OP_REQUIRES (context, (z.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
         OP_REQUIRES (context, (w.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
-        OP_REQUIRES (context, (dz_dx.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
         OP_REQUIRES (context, (dy_dx.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
         OP_REQUIRES (context, (dy2_dx.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));
+        OP_REQUIRES (context, (ybar.shape().dims() == 2), errors::InvalidArgument ("Dim of input should be 2"));

         int context_output_index = 0;
         Tensor* dz2_dx = NULL;

@@ -283,13 +349,14 @@ class UnaggregatedDy2DxOp : public OpKernel {
             context->eigen_device<Device>(), // define actually graph execution device
             z.flat<FPTYPE>().data(),
             w.flat<FPTYPE>().data(),
-            dz_dx.flat<FPTYPE>().data(),
             dy_dx.flat<FPTYPE>().data(),
             dy2_dx.flat<FPTYPE>().data(),
+            ybar.flat<FPTYPE>().data(),
             z.shape().dim_size(0),
             z.shape().dim_size(1),
             w.shape().dim_size(0),
-            dz2_dx->flat<FPTYPE>().data()
+            dz2_dx->flat<FPTYPE>().data(),
+            functype.flat<int32>()(0)
         );
     }
 private:

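As a sanity check on the new `grad()` / `grad_grad()` branches, the closed-form tanh and gelu derivatives can be compared against finite differences. The block below is an illustrative NumPy re-expression of the C++ formulas, not part of the commit; `SQRT_2_PI` is assumed to equal sqrt(2/pi).

```python
# Re-expression of grad()/grad_grad() from unaggregated_grad.cc in NumPy, checked
# against central finite differences. Illustrative only; SQRT_2_PI assumed = sqrt(2/pi).
import numpy as np

SQRT_2_PI = np.sqrt(2.0 / np.pi)
GGELU = 0.044715

def act(x, functype):
    if functype == 1:                                   # tanh
        return np.tanh(x)
    return 0.5 * x * (1.0 + np.tanh(SQRT_2_PI * (x + GGELU * x**3)))  # gelu (tanh approx.)

def grad(xbar, y, functype):                            # first derivative, as in the C++ op
    if functype == 1:
        return 1.0 - y * y
    var = np.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3))
    return (0.5 * SQRT_2_PI * xbar * (1.0 - var * var) * (3.0 * GGELU * xbar**2 + 1.0)
            + 0.5 * var + 0.5)

def grad_grad(xbar, y, functype):                       # second derivative, as in the C++ op
    if functype == 1:
        return -2.0 * y * (1.0 - y * y)
    var1 = np.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3))
    var2 = SQRT_2_PI * (1.0 - var1 * var1) * (3.0 * GGELU * xbar**2 + 1.0)
    return (3.0 * GGELU * SQRT_2_PI * xbar**2 * (1.0 - var1 * var1)
            - SQRT_2_PI * xbar * var2 * (3.0 * GGELU * xbar**2 + 1.0) * var1 + var2)

x, h = np.linspace(-3.0, 3.0, 61), 1e-4
for functype in (1, 2):                                 # 1 = tanh, 2 = gelu
    y = act(x, functype)
    d1 = (act(x + h, functype) - act(x - h, functype)) / (2.0 * h)
    d2 = (act(x + h, functype) - 2.0 * y + act(x - h, functype)) / h**2
    assert np.allclose(grad(x, y, functype), d1, atol=1e-6)
    assert np.allclose(grad_grad(x, y, functype), d2, atol=1e-5)
print("grad() and grad_grad() match finite differences for tanh and gelu")
```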