
Commit b426d14

Authored by Han Wang (wanghan-iapcm), with co-authors pre-commit-ci[bot] and njzjz
fix: make the se attn v2 descriptor energy conservative. (#2905)
This PR fixes issue #2811.

1. Fix the auto-diff issue: the input should be `descriptor_reshape`.
2. Fix the discontinuity introduced in the attention map.

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: Han Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <[email protected]>
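The second point is the energy-conservation fix proper: instead of hard-masking the attention logits with `nmask`/`negative_mask`, the logits are shifted, scaled row- and column-wise by the smooth switch function that vanishes at the cutoff, and shifted back; the post-softmax weights are scaled by the same factors. A minimal numpy sketch of the idea (illustrative only; the shift of 20.0 and the role of `recovered_switch` follow the diff below, everything else is made up):

    import numpy as np

    def smooth_attention(logits, switch, shift=20.0):
        # logits: (nsel, nsel) raw Q.K / temperature scores for one atom
        # switch: (nsel,) smooth switch, ~1 well inside the cutoff, -> 0 at the cutoff
        s_row = switch[None, :]
        s_col = switch[:, None]
        # damp the shifted logits toward -shift as a neighbor approaches the cutoff
        a = (logits + shift) * s_row * s_col - shift
        w = np.exp(a - a.max(axis=-1, keepdims=True))
        w /= w.sum(axis=-1, keepdims=True)            # softmax over neighbors
        # damp the weights themselves, so a vanishing neighbor's row/column -> 0
        return w * s_row * s_col

Because every dependence on a departing neighbor is multiplied by its switch value, the attention weights, and with them the descriptor and the energy, vary continuously as the neighbor crosses the cutoff, which is what the PR title means by making the descriptor energy conservative.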
1 parent 66ea4fc commit b426d14

File tree: 9 files changed, +360 -46 lines changed

deepmd/descriptor/se_atten.py

Lines changed: 24 additions & 4 deletions
@@ -564,6 +564,8 @@ def build(
             self.filter_precision,
         )
         self.negative_mask = -(2 << 32) * (1.0 - self.nmask)
+        # hard coding the magnitude of attention weight shift
+        self.smth_attn_w_shift = 20.0
         # only used when tensorboard was set as true
         tf.summary.histogram("descrpt", self.descrpt)
         tf.summary.histogram("rij", self.rij)
@@ -599,7 +601,9 @@ def build(
         )
         self.recovered_r = (
             tf.reshape(
-                tf.slice(tf.reshape(self.descrpt, [-1, 4]), [0, 0], [-1, 1]),
+                tf.slice(
+                    tf.reshape(self.descrpt_reshape, [-1, 4]), [0, 0], [-1, 1]
+                ),
                 [-1, natoms[0], self.sel_all_a[0]],
             )
             * self.std_looked_up
@@ -865,10 +869,26 @@ def _scaled_dot_attn(
         save_weights=True,
     ):
         attn = tf.matmul(Q / temperature, K, transpose_b=True)
-        attn *= self.nmask
-        attn += self.negative_mask
+        if self.smooth:
+            # (nb x nloc) x nsel
+            nsel = self.sel_all_a[0]
+            attn = (attn + self.smth_attn_w_shift) * tf.reshape(
+                self.recovered_switch, [-1, 1, nsel]
+            ) * tf.reshape(
+                self.recovered_switch, [-1, nsel, 1]
+            ) - self.smth_attn_w_shift
+        else:
+            attn *= self.nmask
+            attn += self.negative_mask
         attn = tf.nn.softmax(attn, axis=-1)
-        attn *= tf.reshape(self.nmask, [-1, attn.shape[-1], 1])
+        if self.smooth:
+            attn = (
+                attn
+                * tf.reshape(self.recovered_switch, [-1, 1, nsel])
+                * tf.reshape(self.recovered_switch, [-1, nsel, 1])
+            )
+        else:
+            attn *= tf.reshape(self.nmask, [-1, attn.shape[-1], 1])
         if save_weights:
             self.attn_weight[layer] = attn[0]  # atom 0
         if dotr:
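The `recovered_r` hunk above is the auto-diff fix from point 1 of the commit message: the slice must be taken from `self.descrpt_reshape`, the tensor the rest of the graph differentiates, not from `self.descrpt`. A minimal TensorFlow sketch of the failure mode, with hypothetical names (assuming, as the message suggests, that downstream gradients are collected with respect to the reshaped tensor):

    import tensorflow as tf

    descrpt = tf.constant([[1.0, 2.0, 3.0, 4.0]])
    descrpt_reshape = tf.reshape(descrpt, [-1, 4])

    with tf.GradientTape() as tape:
        tape.watch(descrpt_reshape)
        # the fixed code slices from the *reshaped* tensor ...
        recovered_r = tf.slice(descrpt_reshape, [0, 0], [-1, 1])
        energy = tf.reduce_sum(recovered_r ** 2)

    # ... so d(energy)/d(descrpt_reshape) picks up the recovered_r contribution.
    # Slicing from `descrpt` instead would leave no path from descrpt_reshape
    # to energy, and this gradient would silently come back as None.
    print(tape.gradient(energy, descrpt_reshape))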

deepmd/op/_tabulate_grad.py

Lines changed: 2 additions & 1 deletion
@@ -55,7 +55,7 @@ def _tabulate_fusion_se_atten_grad_cc(op, dy):
         op.outputs[0],
         is_sorted=op.get_attr("is_sorted"),
     )
-    return [None, None, dy_dx, dy_df, None]
+    return [None, None, dy_dx, dy_df, dy_dtwo]


 @ops.RegisterGradient("TabulateFusionSeAttenGrad")
@@ -68,6 +68,7 @@ def _tabulate_fusion_se_atten_grad_grad_cc(op, dy, dy_, dy_dtwo):
         op.inputs[4],
         dy,
         dy_,
+        dy_dtwo,
         op.inputs[6],
         is_sorted=op.get_attr("is_sorted"),
     )
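The Python-side change follows the usual TensorFlow custom-gradient convention: a registered gradient function must return one tensor per op input, and returning `None` for the two-embedding slot had been dropping d(output)/d(two_embed) entirely. The same idea in a small self-contained sketch using `tf.custom_gradient` on a toy function (not the real tabulate op):

    import tensorflow as tf

    @tf.custom_gradient
    def toy_tabulate(em, two_embed):
        # toy stand-in: the output depends on both inputs, like the fused op
        y = em * (1.0 + two_embed)

        def grad(dy):
            dy_dem = dy * (1.0 + two_embed)
            dy_dtwo = dy * em        # returning None here instead would silently
            return dy_dem, dy_dtwo   # drop the gradient through the attention map

        return y, grad

    em = tf.constant([1.0, 2.0])
    two = tf.constant([0.5, 0.5])
    with tf.GradientTape() as tape:
        tape.watch([em, two])
        out = tf.reduce_sum(toy_tabulate(em, two))
    print(tape.gradient(out, two))   # [1. 2.], no longer None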

source/lib/include/tabulate.h

Lines changed: 4 additions & 0 deletions
@@ -18,6 +18,7 @@ void tabulate_fusion_se_a_cpu(FPTYPE* out,
 template <typename FPTYPE>
 void tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x,
                                    FPTYPE* dy_dem,
+                                   FPTYPE* dy_dtwo,
                                    const FPTYPE* table,
                                    const FPTYPE* table_info,
                                    const FPTYPE* em_x,
@@ -38,6 +39,7 @@ void tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy,
                                         const FPTYPE* two_embed,
                                         const FPTYPE* dz_dy_dem_x,
                                         const FPTYPE* dz_dy_dem,
+                                        const FPTYPE* dz_dy_dtwo,
                                         const int nloc,
                                         const int nnei,
                                         const int last_layer_size,
@@ -125,6 +127,7 @@ void tabulate_fusion_se_a_gpu(FPTYPE* out,
 template <typename FPTYPE>
 void tabulate_fusion_se_a_grad_gpu(FPTYPE* dy_dem_x,
                                    FPTYPE* dy_dem,
+                                   FPTYPE* dy_dtwo,
                                    const FPTYPE* table,
                                    const FPTYPE* table_info,
                                    const FPTYPE* em_x,
@@ -145,6 +148,7 @@ void tabulate_fusion_se_a_grad_grad_gpu(FPTYPE* dz_dy,
                                         const FPTYPE* two_embed,
                                         const FPTYPE* dz_dy_dem_x,
                                         const FPTYPE* dz_dy_dem,
+                                        const FPTYPE* dz_dy_dtwo,
                                         const int nloc,
                                         const int nnei,
                                         const int last_layer_size,

source/lib/src/gpu/tabulate.cu

Lines changed: 28 additions & 6 deletions
@@ -253,6 +253,7 @@ template <typename FPTYPE, int MTILE, int KTILE>
 __global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial(
     FPTYPE* dy_dem_x,
     FPTYPE* dy_dem,
+    FPTYPE* dy_dtwo,
     const FPTYPE* table,
     const FPTYPE* em_x,
     const FPTYPE* em,
@@ -307,6 +308,7 @@ __global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial(
         (var[1] +
          (var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) *
             xx;
+    FPTYPE oldres = res;
     FPTYPE t;
     if (enable_se_atten) {
       t = two_embed[block_idx * nnei * last_layer_size +
@@ -330,6 +332,13 @@ __global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial(
               xx) *
              xx) *
         (enable_se_atten ? res * t + res : res);
+    if (enable_se_atten) {
+      // from ii to ii + (nnei - breakpoint)
+      for (int ii2 = ii; ii2 < ii + nnei - breakpoint; ii2++) {
+        dy_dtwo[block_idx * nnei * last_layer_size + ii2 * last_layer_size +
+                jj] = oldres * res;
+      }
+    }
   }
   GpuSyncThreads();
   for (int kk = 0; kk < MTILE; kk++) {
@@ -357,6 +366,7 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial(
     const FPTYPE* two_embed,
     const FPTYPE* dz_dy_dem_x,
     const FPTYPE* dz_dy_dem,
+    const FPTYPE* dz_dy_dtwo,
     const FPTYPE lower,
     const FPTYPE upper,
     const FPTYPE max,
@@ -404,9 +414,15 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial(
          ((FPTYPE)4. * var[4] + (FPTYPE)5. * var[5] * xx) * xx) *
             xx) *
         xx;
+    FPTYPE two_grad = 0.;
     if (enable_se_atten) {
       FPTYPE t = two_embed[block_idx * nnei * last_layer_size +
                            ii * last_layer_size + thread_idx];
+      // dz_dy_dtwo * res * em
+      // res above should be used instead of res + res * t below
+      two_grad = dz_dy_dtwo[block_idx * nnei * last_layer_size +
+                            ii * last_layer_size + thread_idx] *
+                 res;
       res += res * t;
       res_grad += res_grad * t;
     }
@@ -434,8 +450,8 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial(
     for (int kk = 0; kk < MTILE; kk++) {
       int em_index = block_idx * nnei * MTILE + ii * MTILE + kk;
       iteratorC[kk * last_layer_size + thread_idx] +=
-          (nnei - breakpoint) *
-          (em[em_index] * res_grad * dz_xx + dz_dy_dem[em_index] * res);
+          (nnei - breakpoint) * (em[em_index] * (res_grad * dz_xx + two_grad) +
+                                 dz_dy_dem[em_index] * res);
     }
     mark_table_idx = table_idx;
     if (unloop) {
@@ -764,6 +780,7 @@ void tabulate_fusion_se_a_gpu(FPTYPE* out,
 template <typename FPTYPE>
 void tabulate_fusion_se_a_grad_gpu(FPTYPE* dy_dem_x,
                                    FPTYPE* dy_dem,
+                                   FPTYPE* dy_dtwo,
                                    const FPTYPE* table,
                                    const FPTYPE* table_info,
                                    const FPTYPE* em_x,
@@ -784,9 +801,9 @@ void tabulate_fusion_se_a_grad_gpu(FPTYPE* dy_dem_x,

   tabulate_fusion_se_a_grad_fifth_order_polynomial<FPTYPE, MM, KK>
       <<<nloc, KK * WARP_SIZE, sizeof(FPTYPE) * MM * last_layer_size>>>(
-          dy_dem_x, dy_dem, table, em_x, em, two_embed, dy, table_info[0],
-          table_info[1], table_info[2], table_info[3], table_info[4], nnei,
-          last_layer_size, is_sorted);
+          dy_dem_x, dy_dem, dy_dtwo, table, em_x, em, two_embed, dy,
+          table_info[0], table_info[1], table_info[2], table_info[3],
+          table_info[4], nnei, last_layer_size, is_sorted);
   DPErrcheck(gpuGetLastError());
   DPErrcheck(gpuDeviceSynchronize());
 }
@@ -800,6 +817,7 @@ void tabulate_fusion_se_a_grad_grad_gpu(FPTYPE* dz_dy,
                                         const FPTYPE* two_embed,
                                         const FPTYPE* dz_dy_dem_x,
                                         const FPTYPE* dz_dy_dem,
+                                        const FPTYPE* dz_dy_dtwo,
                                         const int nloc,
                                         const int nnei,
                                         const int last_layer_size,
@@ -812,7 +830,7 @@ void tabulate_fusion_se_a_grad_grad_gpu(FPTYPE* dz_dy,
   DPErrcheck(gpuMemset(dz_dy, 0, sizeof(FPTYPE) * nloc * 4 * last_layer_size));
   tabulate_fusion_se_a_grad_grad_fifth_order_polynomial<FPTYPE, MM, KK>
       <<<nloc, last_layer_size, sizeof(FPTYPE) * MM * last_layer_size>>>(
-          dz_dy, table, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem,
+          dz_dy, table, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem, dz_dy_dtwo,
           table_info[0], table_info[1], table_info[2], table_info[3],
           table_info[4], nnei, last_layer_size, is_sorted);
   DPErrcheck(gpuGetLastError());
@@ -990,6 +1008,7 @@ template void tabulate_fusion_se_a_gpu<double>(double* out,
                                                const bool is_sorted);
 template void tabulate_fusion_se_a_grad_gpu<float>(float* dy_dem_x,
                                                    float* dy_dem,
+                                                   float* dy_dtwo,
                                                    const float* table,
                                                    const float* table_info,
                                                    const float* em_x,
@@ -1002,6 +1021,7 @@ template void tabulate_fusion_se_a_grad_gpu<float>(float* dy_dem_x,
                                                    const bool is_sorted);
 template void tabulate_fusion_se_a_grad_gpu<double>(double* dy_dem_x,
                                                     double* dy_dem,
+                                                    double* dy_dtwo,
                                                     const double* table,
                                                     const double* table_info,
                                                     const double* em_x,
@@ -1021,6 +1041,7 @@ template void tabulate_fusion_se_a_grad_grad_gpu<float>(
     const float* two_embed,
     const float* dz_dy_dem_x,
     const float* dz_dy_dem,
+    const float* dz_dy_dtwo,
     const int nloc,
     const int nnei,
     const int last_layer_size,
@@ -1034,6 +1055,7 @@ template void tabulate_fusion_se_a_grad_grad_gpu<double>(
     const double* two_embed,
     const double* dz_dy_dem_x,
     const double* dz_dy_dem,
+    const double* dz_dy_dtwo,
     const int nloc,
     const int nnei,
     const int last_layer_size,

source/lib/src/tabulate.cc

Lines changed: 42 additions & 10 deletions
@@ -158,6 +158,7 @@ void deepmd::tabulate_fusion_se_a_cpu(FPTYPE* out,
 template <typename FPTYPE>
 void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x,
                                            FPTYPE* dy_dem,
+                                           FPTYPE* dy_dtwo,
                                            const FPTYPE* table,
                                            const FPTYPE* table_info,
                                            const FPTYPE* em_x,
@@ -171,6 +172,9 @@ void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x,
   bool enable_se_atten = two_embed != nullptr;
   memset(dy_dem_x, 0, sizeof(FPTYPE) * nloc * nnei);
   memset(dy_dem, 0, sizeof(FPTYPE) * nloc * nnei * 4);
+  if (enable_se_atten) {
+    memset(dy_dtwo, 0, sizeof(FPTYPE) * nloc * nnei * last_layer_size);
+  }
   FPTYPE const lower = table_info[0];
   FPTYPE const upper = table_info[1];
   FPTYPE const _max = table_info[2];
@@ -212,25 +216,38 @@ void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x,
             a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx;
         FPTYPE g =
             (a1 + (2 * a2 + (3 * a3 + (4 * a4 + 5 * a5 * xx) * xx) * xx) * xx);
+        FPTYPE resold = res;
         if (enable_se_atten) {
           FPTYPE t = two_embed[ii * nnei * last_layer_size +
                                jj * last_layer_size + kk];
           res = res * t + res;
           g += t * g;
         }

+        FPTYPE dotllrr = dot(ll, rr);
         if (unloop) {
-          grad += g * dot(ll, rr) * (nnei - jj);
+          grad += g * dotllrr * (nnei - jj);
           dy_dem[ii * nnei * 4 + jj * 4 + 0] += res * rr[0] * (nnei - jj);
           dy_dem[ii * nnei * 4 + jj * 4 + 1] += res * rr[1] * (nnei - jj);
           dy_dem[ii * nnei * 4 + jj * 4 + 2] += res * rr[2] * (nnei - jj);
           dy_dem[ii * nnei * 4 + jj * 4 + 3] += res * rr[3] * (nnei - jj);
+          if (enable_se_atten) {
+            // fill from jj to nnei
+            for (int jj2 = jj; jj2 < nnei; jj2++) {
+              dy_dtwo[ii * nnei * last_layer_size + jj2 * last_layer_size +
+                      kk] += resold * dotllrr;
+            }
+          }
         } else {
-          grad += g * dot(ll, rr);
+          grad += g * dotllrr;
           dy_dem[ii * nnei * 4 + jj * 4 + 0] += res * rr[0];
           dy_dem[ii * nnei * 4 + jj * 4 + 1] += res * rr[1];
           dy_dem[ii * nnei * 4 + jj * 4 + 2] += res * rr[2];
           dy_dem[ii * nnei * 4 + jj * 4 + 3] += res * rr[3];
+          if (enable_se_atten) {
+            dy_dtwo[ii * nnei * last_layer_size + jj * last_layer_size + kk] +=
+                resold * dotllrr;
+          }
         }
       }
       dy_dem_x[ii * nnei + jj] = grad;
@@ -250,6 +267,7 @@ void deepmd::tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy,
                                                 const FPTYPE* two_embed,
                                                 const FPTYPE* dz_dy_dem_x,
                                                 const FPTYPE* dz_dy_dem,
+                                                const FPTYPE* dz_dy_dtwo,
                                                 const int nloc,
                                                 const int nnei,
                                                 const int last_layer_size,
@@ -300,9 +318,15 @@ void deepmd::tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy,
             ((FPTYPE)3. * a3 + ((FPTYPE)4. * a4 + (FPTYPE)5. * a5 * xx) * xx) *
                 xx) *
             xx;
+        FPTYPE two_grad = 0.;
         if (enable_se_atten) {
           FPTYPE t = two_embed[ii * nnei * last_layer_size +
                                jj * last_layer_size + kk];
+          // dz_dy_dtwo * var * ll
+          // var above should be used instead of var + var * t below
+          two_grad = dz_dy_dtwo[ii * nnei * last_layer_size +
+                                jj * last_layer_size + kk] *
+                     var;
           var += var * t;
           var_grad += var_grad * t;
         }
@@ -329,22 +353,26 @@ void deepmd::tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy,
         */
         if (unloop) {
           dz_dy[ii * last_layer_size * 4 + 0 * last_layer_size + kk] +=
-              (nnei - jj) * (var * hh[0] + dz_xx * var_grad * ll[0]);
+              (nnei - jj) *
+              (var * hh[0] + (dz_xx * var_grad + two_grad) * ll[0]);
           dz_dy[ii * last_layer_size * 4 + 1 * last_layer_size + kk] +=
-              (nnei - jj) * (var * hh[1] + dz_xx * var_grad * ll[1]);
+              (nnei - jj) *
+              (var * hh[1] + (dz_xx * var_grad + two_grad) * ll[1]);
           dz_dy[ii * last_layer_size * 4 + 2 * last_layer_size + kk] +=
-              (nnei - jj) * (var * hh[2] + dz_xx * var_grad * ll[2]);
+              (nnei - jj) *
+              (var * hh[2] + (dz_xx * var_grad + two_grad) * ll[2]);
           dz_dy[ii * last_layer_size * 4 + 3 * last_layer_size + kk] +=
-              (nnei - jj) * (var * hh[3] + dz_xx * var_grad * ll[3]);
+              (nnei - jj) *
+              (var * hh[3] + (dz_xx * var_grad + two_grad) * ll[3]);
         } else {
           dz_dy[ii * last_layer_size * 4 + 0 * last_layer_size + kk] +=
-              var * hh[0] + dz_xx * var_grad * ll[0];
+              var * hh[0] + (dz_xx * var_grad + two_grad) * ll[0];
           dz_dy[ii * last_layer_size * 4 + 1 * last_layer_size + kk] +=
-              var * hh[1] + dz_xx * var_grad * ll[1];
+              var * hh[1] + (dz_xx * var_grad + two_grad) * ll[1];
           dz_dy[ii * last_layer_size * 4 + 2 * last_layer_size + kk] +=
-              var * hh[2] + dz_xx * var_grad * ll[2];
+              var * hh[2] + (dz_xx * var_grad + two_grad) * ll[2];
           dz_dy[ii * last_layer_size * 4 + 3 * last_layer_size + kk] +=
-              var * hh[3] + dz_xx * var_grad * ll[3];
+              var * hh[3] + (dz_xx * var_grad + two_grad) * ll[3];
         }
       }
       if (unloop) {
@@ -660,6 +688,7 @@ template void deepmd::tabulate_fusion_se_a_cpu<double>(
 template void deepmd::tabulate_fusion_se_a_grad_cpu<float>(
     float* dy_dem_x,
     float* dy_dem,
+    float* dy_dtwo,
     const float* table,
     const float* table_info,
     const float* em_x,
@@ -673,6 +702,7 @@ template void deepmd::tabulate_fusion_se_a_grad_cpu<float>(
 template void deepmd::tabulate_fusion_se_a_grad_cpu<double>(
     double* dy_dem_x,
     double* dy_dem,
+    double* dy_dtwo,
     const double* table,
     const double* table_info,
     const double* em_x,
@@ -692,6 +722,7 @@ template void deepmd::tabulate_fusion_se_a_grad_grad_cpu<float>(
     const float* two_embed,
     const float* dz_dy_dem_x,
     const float* dz_dy_dem,
+    const float* dz_dy_dtwo,
     const int nloc,
     const int nnei,
     const int last_layer_size,
@@ -705,6 +736,7 @@ template void deepmd::tabulate_fusion_se_a_grad_grad_cpu<double>(
     const double* two_embed,
     const double* dz_dy_dem_x,
     const double* dz_dy_dem,
+    const double* dz_dy_dtwo,
     const int nloc,
     const int nnei,
     const int last_layer_size,
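The new `dy_dtwo` term can be sanity-checked against finite differences with a few lines of numpy. This is only a sketch of the reduction the CPU kernel performs (made-up shapes, ignoring the table interpolation and the `is_sorted`/`unloop` bookkeeping): the output contracts the 4-component environment rows `ll` with the embedding values scaled by `(1 + two_embed)`, so the derivative with respect to `two_embed` is the `resold * dot(ll, rr)` product accumulated above.

    import numpy as np

    rng = np.random.default_rng(0)
    nnei, last_layer_size = 3, 5
    em  = rng.normal(size=(nnei, 4))                 # ll rows in the kernel
    res = rng.normal(size=(nnei, last_layer_size))   # interpolated table value
    two = rng.normal(size=(nnei, last_layer_size))   # two_embed factor
    dy  = rng.normal(size=(4, last_layer_size))      # upstream gradient (rr)

    def forward(two):
        # out[c, k] = sum_j em[j, c] * res[j, k] * (1 + two[j, k])
        return np.einsum("jc,jk->ck", em, res * (1.0 + two))

    # analytic gradient, the counterpart of dy_dtwo += resold * dot(ll, rr)
    dy_dtwo = res * np.einsum("jc,ck->jk", em, dy)

    # central finite difference on one entry
    eps, j, k = 1e-6, 1, 2
    tp, tm = two.copy(), two.copy()
    tp[j, k] += eps
    tm[j, k] -= eps
    fd = (np.sum(dy * forward(tp)) - np.sum(dy * forward(tm))) / (2 * eps)
    print(np.isclose(fd, dy_dtwo[j, k]))             # True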
