Skip to content

Commit a939ee4

Browse files
Fix incorrect spin-element mapping in make_unlabeled_stru() for ABACUS STRU generation (#852)
This patch resolves a bug in dpdata/abacus/stru.py where the spin values were misaligned with their corresponding elements when generating an unlabeled STRU file. Root cause: In make_unlabeled_stru(), the code correctly identifies the global atom index (iatomtype) for each element type, but then mistakenly uses a separate running counter (natom_tot) to index every property except atomic coordinates. As a result, spins (and any other per-atom attributes) are pulled from the wrong positions in the input array, breaking the element–spin correspondence. Fix: Replace all usages of natom_tot with iatomtype when accessing per-atom data (spins, velocities, etc.). Coordinates were already using the correct index; the change makes the remaining properties consistent. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Improved accuracy in assigning per-atom properties when generating atomic position data, ensuring correct values are used for each atom. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: SuperBing <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6b57f35 commit a939ee4

File tree

1 file changed

+30
-28
lines changed

1 file changed

+30
-28
lines changed

dpdata/abacus/stru.py

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -751,66 +751,68 @@ def process_file_input(file_input, atom_names, input_name):
751751
out += "0.0\n"
752752
out += str(data["atom_numbs"][iele]) + "\n"
753753
for iatom in range(data["atom_numbs"][iele]):
754-
iatomtype = np.nonzero(data["atom_types"] == iele)[0][iatom]
754+
iatomtype = np.nonzero(data["atom_types"] == iele)[0][
755+
iatom
756+
] # it is the atom index
755757
iout = f"{data['coords'][frame_idx][iatomtype, 0]:.12f} {data['coords'][frame_idx][iatomtype, 1]:.12f} {data['coords'][frame_idx][iatomtype, 2]:.12f}"
756758
# add flags for move, velocity, mag, angle1, angle2, and sc
757759
if move is not None:
758760
if (
759-
isinstance(ndarray2list(move[natom_tot]), (list, tuple))
760-
and len(move[natom_tot]) == 3
761+
isinstance(ndarray2list(move[iatomtype]), (list, tuple))
762+
and len(move[iatomtype]) == 3
761763
):
762764
iout += " " + " ".join(
763-
["1" if ii else "0" for ii in move[natom_tot]]
765+
["1" if ii else "0" for ii in move[iatomtype]]
764766
)
765-
elif isinstance(ndarray2list(move[natom_tot]), (int, float, bool)):
766-
iout += " 1 1 1" if move[natom_tot] else " 0 0 0"
767+
elif isinstance(ndarray2list(move[iatomtype]), (int, float, bool)):
768+
iout += " 1 1 1" if move[iatomtype] else " 0 0 0"
767769
else:
768770
iout += " 1 1 1"
769771

770772
if (
771773
velocity is not None
772-
and isinstance(ndarray2list(velocity[natom_tot]), (list, tuple))
773-
and len(velocity[natom_tot]) == 3
774+
and isinstance(ndarray2list(velocity[iatomtype]), (list, tuple))
775+
and len(velocity[iatomtype]) == 3
774776
):
775-
iout += " v " + " ".join([f"{ii:.12f}" for ii in velocity[natom_tot]])
777+
iout += " v " + " ".join([f"{ii:.12f}" for ii in velocity[iatomtype]])
776778

777779
if mag is not None:
778-
if isinstance(ndarray2list(mag[natom_tot]), (list, tuple)) and len(
779-
mag[natom_tot]
780+
if isinstance(ndarray2list(mag[iatomtype]), (list, tuple)) and len(
781+
mag[iatomtype]
780782
) in [1, 3]:
781-
iout += " mag " + " ".join([f"{ii:.12f}" for ii in mag[natom_tot]])
782-
elif isinstance(ndarray2list(mag[natom_tot]), (int, float)):
783-
iout += " mag " + f"{mag[natom_tot]:.12f}"
783+
iout += " mag " + " ".join([f"{ii:.12f}" for ii in mag[iatomtype]])
784+
elif isinstance(ndarray2list(mag[iatomtype]), (int, float)):
785+
iout += " mag " + f"{mag[iatomtype]:.12f}"
784786

785787
if angle1 is not None and isinstance(
786-
ndarray2list(angle1[natom_tot]), (int, float)
788+
ndarray2list(angle1[iatomtype]), (int, float)
787789
):
788-
iout += " angle1 " + f"{angle1[natom_tot]:.12f}"
790+
iout += " angle1 " + f"{angle1[iatomtype]:.12f}"
789791

790792
if angle2 is not None and isinstance(
791-
ndarray2list(angle2[natom_tot]), (int, float)
793+
ndarray2list(angle2[iatomtype]), (int, float)
792794
):
793-
iout += " angle2 " + f"{angle2[natom_tot]:.12f}"
795+
iout += " angle2 " + f"{angle2[iatomtype]:.12f}"
794796

795797
if sc is not None:
796-
if isinstance(ndarray2list(sc[natom_tot]), (list, tuple)) and len(
797-
sc[natom_tot]
798+
if isinstance(ndarray2list(sc[iatomtype]), (list, tuple)) and len(
799+
sc[iatomtype]
798800
) in [1, 3]:
799801
iout += " sc " + " ".join(
800-
["1" if ii else "0" for ii in sc[natom_tot]]
802+
["1" if ii else "0" for ii in sc[iatomtype]]
801803
)
802-
elif isinstance(ndarray2list(sc[natom_tot]), (int, float, bool)):
803-
iout += " sc " + "1" if sc[natom_tot] else "0"
804+
elif isinstance(ndarray2list(sc[iatomtype]), (int, float, bool)):
805+
iout += " sc " + "1" if sc[iatomtype] else "0"
804806

805807
if lambda_ is not None:
806-
if isinstance(ndarray2list(lambda_[natom_tot]), (list, tuple)) and len(
807-
lambda_[natom_tot]
808+
if isinstance(ndarray2list(lambda_[iatomtype]), (list, tuple)) and len(
809+
lambda_[iatomtype]
808810
) in [1, 3]:
809811
iout += " lambda " + " ".join(
810-
[f"{ii:.12f}" for ii in lambda_[natom_tot]]
812+
[f"{ii:.12f}" for ii in lambda_[iatomtype]]
811813
)
812-
elif isinstance(ndarray2list(lambda_[natom_tot]), (int, float)):
813-
iout += " lambda " + f"{lambda_[natom_tot]:.12f}"
814+
elif isinstance(ndarray2list(lambda_[iatomtype]), (int, float)):
815+
iout += " lambda " + f"{lambda_[iatomtype]:.12f}"
814816

815817
out += iout + "\n"
816818
natom_tot += 1

0 commit comments

Comments
 (0)