@@ -59,11 +59,14 @@ def process_spin_input(
5959 coord = coord .reshape (nframes , nloc , 3 )
6060 spin = spin .reshape (nframes , nloc , 3 )
6161 atype_spin = torch .concat ([atype , atype + self .ntypes_real ], dim = - 1 )
62- virtual_coord = coord + spin * (self .virtual_scale_mask .to (atype .device ))[
63- atype
64- ].reshape ([nframes , nloc , 1 ])
62+ spin_dist = spin * (self .virtual_scale_mask .to (atype .device ))[atype ].reshape (
63+ [nframes , nloc , 1 ]
64+ )
65+ virtual_coord = coord + spin_dist
6566 coord_spin = torch .concat ([coord , virtual_coord ], dim = - 2 )
66- return coord_spin , atype_spin
67+ # for spin virial corr
68+ coord_corr = torch .concat ([torch .zeros_like (coord ), - spin_dist ], dim = - 2 )
69+ return coord_spin , atype_spin , coord_corr
6770
6871 def process_spin_input_lower (
6972 self ,
@@ -83,13 +86,18 @@ def process_spin_input_lower(
8386 """
8487 nframes , nall = extended_coord .shape [:2 ]
8588 nloc = nlist .shape [1 ]
86- virtual_extended_coord = extended_coord + extended_spin * (
89+ extended_spin_dist = extended_spin * (
8790 self .virtual_scale_mask .to (extended_atype .device )
8891 )[extended_atype ].reshape ([nframes , nall , 1 ])
92+ virtual_extended_coord = extended_coord + extended_spin_dist
8993 virtual_extended_atype = extended_atype + self .ntypes_real
9094 extended_coord_updated = concat_switch_virtual (
9195 extended_coord , virtual_extended_coord , nloc
9296 )
97+ # for spin virial corr
98+ extended_coord_corr = concat_switch_virtual (
99+ torch .zeros_like (extended_coord ), - extended_spin_dist , nloc
100+ )
93101 extended_atype_updated = concat_switch_virtual (
94102 extended_atype , virtual_extended_atype , nloc
95103 )
@@ -105,6 +113,7 @@ def process_spin_input_lower(
105113 extended_atype_updated ,
106114 nlist_updated ,
107115 mapping_updated ,
116+ extended_coord_corr ,
108117 )
109118
110119 def process_spin_output (
@@ -376,7 +385,7 @@ def spin_sampled_func() -> list[dict[str, Any]]:
376385 sampled = sampled_func ()
377386 spin_sampled = []
378387 for sys in sampled :
379- coord_updated , atype_updated = self .process_spin_input (
388+ coord_updated , atype_updated , _ = self .process_spin_input (
380389 sys ["coord" ], sys ["atype" ], sys ["spin" ]
381390 )
382391 tmp_dict = {
@@ -407,7 +416,9 @@ def forward_common(
407416 do_atomic_virial : bool = False ,
408417 ) -> dict [str , torch .Tensor ]:
409418 nframes , nloc = atype .shape
410- coord_updated , atype_updated = self .process_spin_input (coord , atype , spin )
419+ coord_updated , atype_updated , coord_corr_for_virial = self .process_spin_input (
420+ coord , atype , spin
421+ )
411422 if aparam is not None :
412423 aparam = self .expand_aparam (aparam , nloc * 2 )
413424 model_ret = self .backbone_model .forward_common (
@@ -417,6 +428,7 @@ def forward_common(
417428 fparam = fparam ,
418429 aparam = aparam ,
419430 do_atomic_virial = do_atomic_virial ,
431+ coord_corr_for_virial = coord_corr_for_virial ,
420432 )
421433 model_output_type = self .backbone_model .model_output_type ()
422434 if "mask" in model_output_type :
@@ -463,6 +475,7 @@ def forward_common_lower(
463475 extended_atype_updated ,
464476 nlist_updated ,
465477 mapping_updated ,
478+ extended_coord_corr_for_virial ,
466479 ) = self .process_spin_input_lower (
467480 extended_coord , extended_atype , extended_spin , nlist , mapping = mapping
468481 )
@@ -478,6 +491,7 @@ def forward_common_lower(
478491 do_atomic_virial = do_atomic_virial ,
479492 comm_dict = comm_dict ,
480493 extra_nlist_sort = extra_nlist_sort ,
494+ extended_coord_corr = extended_coord_corr_for_virial ,
481495 )
482496 model_output_type = self .backbone_model .model_output_type ()
483497 if "mask" in model_output_type :
@@ -550,6 +564,11 @@ def translated_output_def(self) -> dict[str, Any]:
550564 output_def ["force" ].squeeze (- 2 )
551565 output_def ["force_mag" ] = deepcopy (out_def_data ["energy_derv_r_mag" ])
552566 output_def ["force_mag" ].squeeze (- 2 )
567+ if self .do_grad_c ("energy" ):
568+ output_def ["virial" ] = deepcopy (out_def_data ["energy_derv_c_redu" ])
569+ output_def ["virial" ].squeeze (- 2 )
570+ output_def ["atom_virial" ] = deepcopy (out_def_data ["energy_derv_c" ])
571+ output_def ["atom_virial" ].squeeze (- 3 )
553572 return output_def
554573
555574 def forward (
@@ -578,7 +597,10 @@ def forward(
578597 if self .backbone_model .do_grad_r ("energy" ):
579598 model_predict ["force" ] = model_ret ["energy_derv_r" ].squeeze (- 2 )
580599 model_predict ["force_mag" ] = model_ret ["energy_derv_r_mag" ].squeeze (- 2 )
581- # not support virial by far
600+ if self .backbone_model .do_grad_c ("energy" ):
601+ model_predict ["virial" ] = model_ret ["energy_derv_c_redu" ].squeeze (- 2 )
602+ if do_atomic_virial :
603+ model_predict ["atom_virial" ] = model_ret ["energy_derv_c" ].squeeze (- 3 )
582604 return model_predict
583605
584606 @torch .jit .export
@@ -615,5 +637,10 @@ def forward_lower(
615637 model_predict ["extended_force_mag" ] = model_ret [
616638 "energy_derv_r_mag"
617639 ].squeeze (- 2 )
618- # not support virial by far
640+ if self .backbone_model .do_grad_c ("energy" ):
641+ model_predict ["virial" ] = model_ret ["energy_derv_c_redu" ].squeeze (- 2 )
642+ if do_atomic_virial :
643+ model_predict ["extended_virial" ] = model_ret ["energy_derv_c" ].squeeze (
644+ - 3
645+ )
619646 return model_predict
0 commit comments