4 changes: 3 additions & 1 deletion dptb/nnops/loss.py
@@ -296,6 +296,8 @@ def forward(
         # hopping_loss += self.loss1(pre, tgt) + torch.sqrt(self.loss2(pre, tgt))
 
         # return hopping_loss + onsite_loss
+
+
 
 
 @Loss.register("hamil_abs")
@@ -1004,4 +1006,4 @@ def __cal_norm__(self, irreps: Irreps, x: torch.Tensor):
             tensor = tensor.norm(dim=-1)
             out.append(tensor)
 
-        return torch.cat(out, dim=-1).squeeze(0)
+        return torch.cat(out, dim=-1).squeeze(0)
9 changes: 7 additions & 2 deletions dptb/nnops/trainer.py
@@ -38,11 +38,16 @@ def __init__(
         # get the task from train_datasets label
         self.task = None
         if self.train_datasets.get_Hamiltonian:
-            self.task = "hamiltonians"
+            if self.train_datasets.get_eigenvalues:
+                self.task = "hamil_eigvals"
+            else:
+                self.task = "hamiltonians"
         elif self.train_datasets.get_DM:
             self.task = "DM"
-        else:
+        elif self.train_datasets.get_eigenvalues:
             self.task = "eigenvalues"
+        else:
+            raise RuntimeError("The train data set should have at least one of get_Hamiltonian, get_DM or get_eigenvalues set to True.")
 
         self.use_reference = False
         if reference_datasets is not None:
8 changes: 7 additions & 1 deletion dptb/utils/argcheck.py
@@ -837,6 +837,11 @@ def loss_options():
         Argument("spin_deg", int, optional=True, default=2, doc="The spin degeneracy of band structure. Default: 2"),
     ]
 
+    eig_ham = [
+        Argument("coeff_ham", float, optional=True, default=1., doc="The coefficient of the hamiltonian penalty. Default: 1"),
+        Argument("coeff_ovp", float, optional=True, default=1., doc="The coefficient of the overlap penalty. Default: 1"),
+    ]
+
     skints = [
         Argument("skdata", str, optional=False, doc="The path to the skfile or sk database."),
     ]
@@ -848,6 +853,7 @@ def loss_options():
         Argument("hamil_abs", dict, sub_fields=hamil),
         Argument("hamil_blas", dict, sub_fields=hamil),
         Argument("hamil_wt", dict, sub_fields=hamil+wt),
+        Argument("eig_ham", dict, sub_fields=hamil+eigvals+eig_ham),
     ], optional=False, doc=doc_method)
@@ -1750,4 +1756,4 @@ def normalize_skf2nnsk(data):
     base.check_value(data, strict=True)
 
     return data
-
+
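For reference, a hedged sketch of how the new eig_ham loss options might appear in an input file; the nesting under train_options → loss_options → train is inferred from the check in dptb/utils/config_check.py below, and the coefficient values shown are just the declared defaults:

# Hypothetical loss_options fragment selecting the new "eig_ham" method.
# The method name and coeff_* keys come from the Arguments added above;
# the remaining sub-fields (hamil + eigvals) are omitted here for brevity.
loss_options = {
    "train": {
        "method": "eig_ham",
        "coeff_ham": 1.0,  # weight of the hamiltonian penalty (default 1)
        "coeff_ovp": 1.0,  # weight of the overlap penalty (default 1)
    }
}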
4 changes: 2 additions & 2 deletions dptb/utils/config_check.py
@@ -37,8 +37,8 @@ def check_config_train(
     if train_data_config.get("get_Hamiltonian") and not train_data_config.get("get_eigenvalues"):
         assert jdata['train_options']['loss_options']['train'].get("method").startswith("hamil")
 
-    if train_data_config.get("get_Hamiltonian") and train_data_config.get("get_eigenvalues"):
-        raise RuntimeError("The train data set should not have both get_Hamiltonian and get_eigenvalues set to True.")
+    # if train_data_config.get("get_Hamiltonian") and train_data_config.get("get_eigenvalues"):
+    #     raise RuntimeError("The train data set should not have both get_Hamiltonian and get_eigenvalues set to True.")
Comment on lines 40 to 41
⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Description: Check for code that assumes mutual exclusivity of get_Hamiltonian and get_eigenvalues

# Search for conditional logic that checks both flags
echo "=== Checking for conditional logic with both flags ==="
rg -nP --type=py -C3 'get_Hamiltonian.*get_eigenvalues|get_eigenvalues.*get_Hamiltonian'

echo ""
echo "=== Searching for eig_ham and EigHamLoss references ==="
# Search for loss method registration and usage
rg -nP --type=py -C3 'eig_ham|EigHamLoss'

Repository: deepmodeling/DeePTB

Length of output: 3591


🏁 Script executed:

#!/bin/bash

# Check the full config_check.py file around the guard area
echo "=== Full context of config_check.py lines 30-50 ==="
sed -n '30,50p' dptb/utils/config_check.py

# Check if there's any other validation for loss method selection
echo ""
echo "=== Searching for loss method validation logic ==="
rg -nP --type=py -C5 'loss_options.*method|train_options.*method' dptb/utils/config_check.py

# Check argcheck.py loss method validation structure
echo ""
echo "=== Loss method validation in argcheck.py ==="
rg -nP --type=py -B5 -A15 'eig_ham.*='

Repository: deepmodeling/DeePTB

Length of output: 5056


🏁 Script executed:

#!/bin/bash

# Check if there's validation that ensures loss method matches data config
echo "=== Checking for loss method validation elsewhere ==="
rg -nP --type=py 'get_eigenvalues|get_Hamiltonian' dptb/ --max-count=20

# Check EigHamLoss implementation to see if it validates configs
echo ""
echo "=== EigHamLoss implementation ==="
sed -n '300,350p' dptb/nnops/loss.py

# Check if there's any assertion about valid loss method + data flag combinations
echo ""
echo "=== Searching for combined validation logic ==="
rg -nP --type=py -A10 'def.*loss.*method|check.*loss' dptb/utils/

Repository: deepmodeling/DeePTB

Length of output: 20108


🏁 Script executed:

#!/bin/bash

# Check trainer.py for loss method validation
echo "=== Checking trainer.py around line 40 ==="
sed -n '35,80p' dptb/nnops/trainer.py

# Check if there's any validation that ties loss method to data config
echo ""
echo "=== Searching for loss method initialization/validation ==="
rg -nP --type=py -B5 -A5 'loss_method.*=|Loss\.register' dptb/nnops/trainer.py | head -50

# Check if config_check.py has any other validations we missed
echo ""
echo "=== Full config_check.py validation logic ==="
cat dptb/utils/config_check.py | head -80

Repository: deepmodeling/DeePTB

Length of output: 6483


🏁 Script executed:

#!/bin/bash

# Check the full task assignment logic in trainer.py
echo "=== Full task assignment logic in trainer.py ==="
sed -n '38,50p' dptb/nnops/trainer.py

# Check how self.task is used in the trainer
echo ""
echo "=== How self.task is used in trainer ==="
rg -nP --type=py 'self\.task' dptb/nnops/trainer.py

# Check if there's any post-validation that would catch this gap
echo ""
echo "=== Checking for loss method to data flag validation ==="
rg -nP --type=py -C3 'eig_ham.*method|method.*eig_ham' dptb/

Repository: deepmodeling/DeePTB

Length of output: 811


Add validation to enforce the eig_ham loss method when both get_Hamiltonian and get_eigenvalues are True.

Removing this guard creates a critical validation gap. The remaining validation in config_check.py (lines 34-38) only covers the case where get_Hamiltonian is set without get_eigenvalues; nothing now checks the case where both flags are True. This allows invalid configurations in which both flags are enabled together with an incompatible loss method (e.g., "hamil_abs" or "eigvals").

Additionally, the task assignment logic in trainer.py (lines 38-45) doesn't account for the combined case:

if self.train_datasets.get_Hamiltonian:
    self.task = "hamiltonians"
elif self.train_datasets.get_DM:
    self.task = "DM"
else:
    self.task = "eigenvalues"

When both flags are True, this unconditionally assigns task="hamiltonians", bypassing proper handling of the new eig_ham loss workflow.

Add an assertion in config_check.py to enforce that when both flags are True, the loss method must be "eig_ham".
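A minimal sketch of that guard, following the pattern of the existing assertion at lines 34-38 (exact placement and error type are a judgment call):

# Sketch only: enforce the eig_ham loss method for the combined case.
if train_data_config.get("get_Hamiltonian") and train_data_config.get("get_eigenvalues"):
    method = jdata['train_options']['loss_options']['train'].get("method")
    assert method == "eig_ham", (
        f"Loss method must be 'eig_ham' when both get_Hamiltonian and "
        f"get_eigenvalues are True, got '{method}'."
    )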

🤖 Prompt for AI Agents
In @dptb/utils/config_check.py around lines 40-41, add a validation so that if
train_data_config.get("get_Hamiltonian") and
train_data_config.get("get_eigenvalues") are both True, the code asserts or
raises unless the configured loss method equals "eig_ham" (i.e., enforce
loss_method == "eig_ham"); also update the task-selection logic in trainer.py
(the block that checks self.train_datasets.get_Hamiltonian / get_DM / else) to
detect the combined case (both get_Hamiltonian and get_eigenvalues True) and
set self.task = "eig_ham" so the combined workflow is routed correctly.


     #if jdata["data_options"].get("validation"):