Commit b426d14
fix: make the se attn v2 descriptor energy conservative. (#2905)
This PR fixes issue #2811
1. fix the auto-diff issue: input should be `descriptor_reshape`.
2. fix the discontinuity introduced in the attention map.
---------
Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: Han Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <[email protected]>1 parent 66ea4fc commit b426d14
File tree
9 files changed
+360
-46
lines changed- deepmd
- descriptor
- op
- source
- lib
- include
- src
- gpu
- tests
- op
- tests
9 files changed
+360
-46
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
564 | 564 | | |
565 | 565 | | |
566 | 566 | | |
| 567 | + | |
| 568 | + | |
567 | 569 | | |
568 | 570 | | |
569 | 571 | | |
| |||
599 | 601 | | |
600 | 602 | | |
601 | 603 | | |
602 | | - | |
| 604 | + | |
| 605 | + | |
| 606 | + | |
603 | 607 | | |
604 | 608 | | |
605 | 609 | | |
| |||
865 | 869 | | |
866 | 870 | | |
867 | 871 | | |
868 | | - | |
869 | | - | |
| 872 | + | |
| 873 | + | |
| 874 | + | |
| 875 | + | |
| 876 | + | |
| 877 | + | |
| 878 | + | |
| 879 | + | |
| 880 | + | |
| 881 | + | |
| 882 | + | |
870 | 883 | | |
871 | | - | |
| 884 | + | |
| 885 | + | |
| 886 | + | |
| 887 | + | |
| 888 | + | |
| 889 | + | |
| 890 | + | |
| 891 | + | |
872 | 892 | | |
873 | 893 | | |
874 | 894 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
55 | 55 | | |
56 | 56 | | |
57 | 57 | | |
58 | | - | |
| 58 | + | |
59 | 59 | | |
60 | 60 | | |
61 | 61 | | |
| |||
68 | 68 | | |
69 | 69 | | |
70 | 70 | | |
| 71 | + | |
71 | 72 | | |
72 | 73 | | |
73 | 74 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
18 | 18 | | |
19 | 19 | | |
20 | 20 | | |
| 21 | + | |
21 | 22 | | |
22 | 23 | | |
23 | 24 | | |
| |||
38 | 39 | | |
39 | 40 | | |
40 | 41 | | |
| 42 | + | |
41 | 43 | | |
42 | 44 | | |
43 | 45 | | |
| |||
125 | 127 | | |
126 | 128 | | |
127 | 129 | | |
| 130 | + | |
128 | 131 | | |
129 | 132 | | |
130 | 133 | | |
| |||
145 | 148 | | |
146 | 149 | | |
147 | 150 | | |
| 151 | + | |
148 | 152 | | |
149 | 153 | | |
150 | 154 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
253 | 253 | | |
254 | 254 | | |
255 | 255 | | |
| 256 | + | |
256 | 257 | | |
257 | 258 | | |
258 | 259 | | |
| |||
307 | 308 | | |
308 | 309 | | |
309 | 310 | | |
| 311 | + | |
310 | 312 | | |
311 | 313 | | |
312 | 314 | | |
| |||
330 | 332 | | |
331 | 333 | | |
332 | 334 | | |
| 335 | + | |
| 336 | + | |
| 337 | + | |
| 338 | + | |
| 339 | + | |
| 340 | + | |
| 341 | + | |
333 | 342 | | |
334 | 343 | | |
335 | 344 | | |
| |||
357 | 366 | | |
358 | 367 | | |
359 | 368 | | |
| 369 | + | |
360 | 370 | | |
361 | 371 | | |
362 | 372 | | |
| |||
404 | 414 | | |
405 | 415 | | |
406 | 416 | | |
| 417 | + | |
407 | 418 | | |
408 | 419 | | |
409 | 420 | | |
| 421 | + | |
| 422 | + | |
| 423 | + | |
| 424 | + | |
| 425 | + | |
410 | 426 | | |
411 | 427 | | |
412 | 428 | | |
| |||
434 | 450 | | |
435 | 451 | | |
436 | 452 | | |
437 | | - | |
438 | | - | |
| 453 | + | |
| 454 | + | |
439 | 455 | | |
440 | 456 | | |
441 | 457 | | |
| |||
764 | 780 | | |
765 | 781 | | |
766 | 782 | | |
| 783 | + | |
767 | 784 | | |
768 | 785 | | |
769 | 786 | | |
| |||
784 | 801 | | |
785 | 802 | | |
786 | 803 | | |
787 | | - | |
788 | | - | |
789 | | - | |
| 804 | + | |
| 805 | + | |
| 806 | + | |
790 | 807 | | |
791 | 808 | | |
792 | 809 | | |
| |||
800 | 817 | | |
801 | 818 | | |
802 | 819 | | |
| 820 | + | |
803 | 821 | | |
804 | 822 | | |
805 | 823 | | |
| |||
812 | 830 | | |
813 | 831 | | |
814 | 832 | | |
815 | | - | |
| 833 | + | |
816 | 834 | | |
817 | 835 | | |
818 | 836 | | |
| |||
990 | 1008 | | |
991 | 1009 | | |
992 | 1010 | | |
| 1011 | + | |
993 | 1012 | | |
994 | 1013 | | |
995 | 1014 | | |
| |||
1002 | 1021 | | |
1003 | 1022 | | |
1004 | 1023 | | |
| 1024 | + | |
1005 | 1025 | | |
1006 | 1026 | | |
1007 | 1027 | | |
| |||
1021 | 1041 | | |
1022 | 1042 | | |
1023 | 1043 | | |
| 1044 | + | |
1024 | 1045 | | |
1025 | 1046 | | |
1026 | 1047 | | |
| |||
1034 | 1055 | | |
1035 | 1056 | | |
1036 | 1057 | | |
| 1058 | + | |
1037 | 1059 | | |
1038 | 1060 | | |
1039 | 1061 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
158 | 158 | | |
159 | 159 | | |
160 | 160 | | |
| 161 | + | |
161 | 162 | | |
162 | 163 | | |
163 | 164 | | |
| |||
171 | 172 | | |
172 | 173 | | |
173 | 174 | | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
174 | 178 | | |
175 | 179 | | |
176 | 180 | | |
| |||
212 | 216 | | |
213 | 217 | | |
214 | 218 | | |
| 219 | + | |
215 | 220 | | |
216 | 221 | | |
217 | 222 | | |
218 | 223 | | |
219 | 224 | | |
220 | 225 | | |
221 | 226 | | |
| 227 | + | |
222 | 228 | | |
223 | | - | |
| 229 | + | |
224 | 230 | | |
225 | 231 | | |
226 | 232 | | |
227 | 233 | | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
228 | 241 | | |
229 | | - | |
| 242 | + | |
230 | 243 | | |
231 | 244 | | |
232 | 245 | | |
233 | 246 | | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
234 | 251 | | |
235 | 252 | | |
236 | 253 | | |
| |||
250 | 267 | | |
251 | 268 | | |
252 | 269 | | |
| 270 | + | |
253 | 271 | | |
254 | 272 | | |
255 | 273 | | |
| |||
300 | 318 | | |
301 | 319 | | |
302 | 320 | | |
| 321 | + | |
303 | 322 | | |
304 | 323 | | |
305 | 324 | | |
| 325 | + | |
| 326 | + | |
| 327 | + | |
| 328 | + | |
| 329 | + | |
306 | 330 | | |
307 | 331 | | |
308 | 332 | | |
| |||
329 | 353 | | |
330 | 354 | | |
331 | 355 | | |
332 | | - | |
| 356 | + | |
| 357 | + | |
333 | 358 | | |
334 | | - | |
| 359 | + | |
| 360 | + | |
335 | 361 | | |
336 | | - | |
| 362 | + | |
| 363 | + | |
337 | 364 | | |
338 | | - | |
| 365 | + | |
| 366 | + | |
339 | 367 | | |
340 | 368 | | |
341 | | - | |
| 369 | + | |
342 | 370 | | |
343 | | - | |
| 371 | + | |
344 | 372 | | |
345 | | - | |
| 373 | + | |
346 | 374 | | |
347 | | - | |
| 375 | + | |
348 | 376 | | |
349 | 377 | | |
350 | 378 | | |
| |||
660 | 688 | | |
661 | 689 | | |
662 | 690 | | |
| 691 | + | |
663 | 692 | | |
664 | 693 | | |
665 | 694 | | |
| |||
673 | 702 | | |
674 | 703 | | |
675 | 704 | | |
| 705 | + | |
676 | 706 | | |
677 | 707 | | |
678 | 708 | | |
| |||
692 | 722 | | |
693 | 723 | | |
694 | 724 | | |
| 725 | + | |
695 | 726 | | |
696 | 727 | | |
697 | 728 | | |
| |||
705 | 736 | | |
706 | 737 | | |
707 | 738 | | |
| 739 | + | |
708 | 740 | | |
709 | 741 | | |
710 | 742 | | |
| |||
0 commit comments