feat: basis transfer from small basis to larger ones by finetuning on the eigenvalues. #305
base: main
Conversation
align unit
📝 Walkthrough
These changes broaden input type flexibility in loss configuration options and enhance task differentiation logic in the trainer's dataset initialization based on Hamiltonian and eigenvalue request flags.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
🚥 Pre-merge checks: ✅ 1 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
📜 Recent review details
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
🧰 Additional context used
🪛 Ruff (0.14.10)
dptb/nnops/trainer.py
50-50: Avoid specifying long messages outside the exception class (TRY003)
⏰ Context from checks skipped due to timeout of 90000ms (1). You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms).
🔇 Additional comments (2)
Actionable comments posted: 6
Fix all issues with AI Agents 🤖
In @dptb/data/AtomicData.py:
- Line 110: AtomicDataDict.EIGENVECTOR_KEY is only in _DEFAULT_GRAPH_FIELDS but
not in _DEFAULT_NESTED_FIELDS, so eigenvector tensors won’t be converted to
nested tensors by _process_dict like ENERGY_EIGENVALUE_KEY is; add
AtomicDataDict.EIGENVECTOR_KEY to the _DEFAULT_NESTED_FIELDS list (the same
place ENERGY_EIGENVALUE_KEY is listed) so _process_dict will treat eigenvectors
as nested fields and avoid shape/dtype mismatches when batching multi-frame
data.
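A minimal sketch of the suggested edit, assuming _DEFAULT_NESTED_FIELDS is a flat list of key constants; the surrounding entries shown here are illustrative, not copied from the file:

# dptb/data/AtomicData.py (fragment, sketch only)
_DEFAULT_NESTED_FIELDS = [
    AtomicDataDict.ENERGY_EIGENVALUE_KEY,  # already treated as a nested field by _process_dict
    AtomicDataDict.EIGENVECTOR_KEY,        # proposed addition so eigenvectors are nested the same way
]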
In @dptb/nnops/loss.py:
- Around line 315-316: The EigHamLoss coefficient logic can divide by zero and
currently nulls eigloss by default; change the defaults to include both terms
(e.g., coeff_ham: float=0.5, coeff_ovp: float=0.5) and add a small lower-bound
guard (eps, e.g., 1e-8) when normalizing weights so expressions like
(self.coeff_ovp / self.coeff_ham) cannot throw; update the
normalization/combination code that computes ham_loss and eigloss (references:
EigHamLoss, self.coeff_ham, self.coeff_ovp, ham_loss, eigloss) to use guarded
normalization (if self.coeff_ham < eps handle as special case or use eps) and
ensure both losses contribute according to the normalized weights.
- Around line 352-353: The conditional that sets batch when self.onsite_shift is
true uses torch.zeros(...) which defaults to CPU and can cause device/dtype
mismatches; change the fallback to create the tensor on the same device/dtype as
the positions tensor used earlier (access via
data[AtomicDataDict.POSITIONS_KEY]) or use a device-aware constructor (e.g.,
torch.zeros with device=pos.device and dtype=pos.dtype or torch.zeros_like) so
the batch tensor lives on the same device and has compatible dtype before any
comparisons/indexing.
In @dptb/utils/argcheck.py:
- Around line 840-843: Update the docstring for the Argument named coeff_ovp
inside the eig_ham list so it correctly describes the overlap penalty rather
than the Hamiltonian penalty; replace the current text "The coefficient of the
hamiltonian penalty. Default: 1" with something like "The coefficient of the
overlap penalty. Default: 1" so coeff_ham remains the Hamiltonian coefficient
and coeff_ovp clearly documents the overlap coefficient.
In @dptb/utils/config_check.py:
- Around line 40-41: Add a validation in config_check.py that if
train_data_config.get("get_Hamiltonian") and
train_data_config.get("get_eigenvalues") are both True then assert or raise
unless the configured loss method equals "eig_ham" (i.e., enforce loss_method ==
"eig_ham"); also update the task-selection logic in trainer.py (the block that
checks self.train_datasets.get_Hamiltonian / get_DM / else) to detect the
combined case (both get_Hamiltonian and get_eigenvalues True) and set self.task
= "eig_ham" so the combined workflow is routed correctly.
🧹 Nitpick comments (3)
dptb/nn/energy.py (1)
99-103: Consider using explicit Optional type hints and avoiding mutable defaults.
Static analysis flags: (1) PEP 484 prohibits implicit Optional for parameters with None defaults; (2) torch.device("cpu") in argument defaults is evaluated once at function definition time.
🔎 Proposed fix
- s_edge_field: str = None,
- s_node_field: str = None,
- s_out_field: str = None,
+ s_edge_field: Optional[str] = None,
+ s_node_field: Optional[str] = None,
+ s_out_field: Optional[str] = None,
  dtype: Union[str, torch.dtype] = torch.float32,
- device: Union[str, torch.device] = torch.device("cpu")):
+ device: Union[str, torch.device] = "cpu"):
dptb/nnops/loss.py (2)
352-372: Consider extracting duplicated onsite_shift logic.
The onsite_shift handling (lines 352-372) is nearly identical to the implementation in HamilLossAbs (lines 438-458) and other loss classes. Consider extracting this into a shared helper method to reduce duplication and maintenance burden.
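For illustration, a shared helper along these lines could replace the copied blocks; the function name and the positions-key argument are hypothetical, and only the batch-index fallback visible in this diff is covered:

import torch

def _batch_index(data, positions_key):
    # Fallback shared by the onsite_shift branches: reuse the stored batch index,
    # otherwise treat the whole input as a single frame (index 0 for every atom),
    # created on the same device as the positions to avoid CPU/GPU mismatches.
    batch = data.get("batch")
    if batch is None:
        pos = data[positions_key]
        batch = torch.zeros(pos.shape[0], dtype=torch.long, device=pos.device)
    return batch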
304-304: Address static analysis hints for type annotations and unused kwargs.
Per static analysis:
- Lines 304, 313: Use explicit Optional[T] instead of implicit T = None
- Line 309: Avoid torch.device("cpu") in defaults
- Line 317: **kwargs is unused
🔎 Proposed fix
  def __init__(
      self,
-     basis: Dict[str, Union[str, list]]=None,
+     basis: Optional[Dict[str, Union[str, list]]]=None,
      idp: Union[OrbitalMapper, None]=None,
      overlap: bool=False,
      onsite_shift: bool=False,
      dtype: Union[str, torch.dtype] = torch.float32,
-     device: Union[str, torch.device] = torch.device("cpu"),
+     device: Union[str, torch.device] = "cpu",
      diff_on: bool=False,
      eout_weight: float=0.01,
      diff_weight: float=0.01,
-     diff_valence: dict=None,
+     diff_valence: Optional[dict]=None,
      spin_deg: int = 2,
      coeff_ham: float=1.,
      coeff_ovp: float=1.,
-     **kwargs,
  ):
Note: If **kwargs is needed for compatibility with the loss registry, keep it but add a comment explaining its purpose.
Also applies to: 309-309, 313-313, 317-317
📜 Review details
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
dptb/data/AtomicData.py
dptb/data/_keys.py
dptb/nn/__init__.py
dptb/nn/energy.py
dptb/nnops/loss.py
dptb/utils/argcheck.py
dptb/utils/config_check.py
🧰 Additional context used
🧬 Code graph analysis (3)
dptb/nn/energy.py (2)
dptb/data/transforms.py (1)
OrbitalMapper (395-882)
dptb/nn/hr2hk.py (2)
HR2HK (10-189)
forward (48-189)
dptb/nn/__init__.py (1)
dptb/nn/energy.py (1)
Eigh(90-172)
dptb/nnops/loss.py (1)
dptb/data/transforms.py (1)
OrbitalMapper(395-882)
🪛 Ruff (0.14.10)
dptb/nn/energy.py
99-99: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
100-100: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
101-101: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
103-103: Do not perform function call torch.device in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable
(B008)
dptb/nnops/loss.py
304-304: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
309-309: Do not perform function call torch.device in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable
(B008)
313-313: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
317-317: Unused method argument: kwargs
(ARG002)
363-363: Consider [0, *slices] instead of concatenation
Replace with [0, *slices]
(RUF005)
🔇 Additional comments (6)
dptb/nn/__init__.py (1)
7-7: LGTM!
The import and export of the new Eigh class are correctly wired, making the eigendecomposition module publicly accessible from dptb.nn.
Also applies to: 15-15
dptb/data/_keys.py (1)
45-45: LGTM!
The new constant is appropriately named and positioned alongside related eigenvalue keys.
dptb/utils/argcheck.py (1)
856-856: LGTM!
The new eig_ham loss variant correctly composes Hamiltonian, eigenvalue, and Hamiltonian-penalty sub-fields, aligning with the basis-transfer workflow introduced in this PR.
dptb/nn/energy.py (2)
83-88: LGTM!
Correctly restores the KPOINT_KEY to its original nested/non-nested representation after processing.
164-165: Inconsistent tensor structure between eigenvalues and eigenvectors.
eigval_field is wrapped as a nested tensor while eigvec_field is a regular tensor. This inconsistency may cause issues downstream if consumers expect both to have the same structure. Please verify this is intentional.
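If the asymmetry is unintentional, one option is to nest the eigenvectors the same way as the eigenvalues; a sketch using torch.nested with illustrative shapes (per-frame [n_kpoints, n_band, n_band] tensors):

import torch

# two frames with different numbers of k-points (illustrative data only)
eigvecs_per_frame = [torch.randn(4, 8, 8), torch.randn(2, 8, 8)]

# wrap as a nested tensor so ragged k-point counts batch consistently with the eigenvalues
eigvec_nested = torch.nested.as_nested_tensor(eigvecs_per_frame)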
dptb/nnops/loss.py (1)
300-398: Overall structure of EigHamLoss looks reasonable, pending fixes above.
The class correctly:
- Registers with the loss registry under "eig_ham"
- Initializes both L1 and MSE loss functions with the combined loss formula
- Delegates eigenvalue loss computation to EigLoss
- Handles overlap terms conditionally
After addressing the coefficient logic issues, this will provide a useful combined loss for basis transfer training.
coeff_ham: float=1.,
coeff_ovp: float=1.,
Potential division by zero and default coeff_ham=1.0 makes eigloss unused.
Two related issues with the coefficient logic:
- Division by zero (Line 391): If coeff_ham=0, the expression (self.coeff_ovp / self.coeff_ham) will raise ZeroDivisionError.
- Default behavior ignores eigloss (Line 397): With coeff_ham=1.0 (default), the final loss becomes 1.0 * ham_loss + 0.0 * eigloss, effectively ignoring the eigenvalue loss entirely. This seems unintentional for a loss class named EigHamLoss.
Consider either:
- Changing the default to a value that includes both losses (e.g., coeff_ham=0.5)
- Adding a lower bound check to prevent division by zero
- Clarifying the intended behavior in documentation
🔎 Proposed fix (add bounds check and adjust default)
- coeff_ham: float=1.,
+ coeff_ham: float=0.5,
coeff_ovp: float=1.,
**kwargs,
):
super(EigHamLoss, self).__init__()
self.loss1 = nn.L1Loss()
self.loss2 = nn.MSELoss()
self.overlap = overlap
self.device = device
self.onsite_shift = onsite_shift
self.coeff_ham = coeff_ham
- assert self.coeff_ham <= 1.
+ assert 0. < self.coeff_ham <= 1., "coeff_ham must be in (0, 1]"
self.coeff_ovp = coeff_ovp
Also applies to: 325-327, 391-391, 397-397
🤖 Prompt for AI Agents
In @dptb/nnops/loss.py around lines 315-316, The EigHamLoss coefficient logic
can divide by zero and currently nulls eigloss by default; change the defaults
to include both terms (e.g., coeff_ham: float=0.5, coeff_ovp: float=0.5) and add
a small lower-bound guard (eps, e.g., 1e-8) when normalizing weights so
expressions like (self.coeff_ovp / self.coeff_ham) cannot throw; update the
normalization/combination code that computes ham_loss and eigloss (references:
EigHamLoss, self.coeff_ham, self.coeff_ovp, ham_loss, eigloss) to use guarded
normalization (if self.coeff_ham < eps handle as special case or use eps) and
ensure both losses contribute according to the normalized weights.
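A standalone sketch of the guarded weighting the prompt asks for; how ham_loss folds in the overlap term is an assumption here, not taken from the source:

EPS = 1e-8

def combine_losses(loss_ham, loss_ovp, eigloss, coeff_ham=0.5, coeff_ovp=0.5):
    # clamp so the ratio below can never divide by zero
    coeff_ham = max(coeff_ham, EPS)
    # assumed inner mix of Hamiltonian and overlap errors, normalized by the total weight
    ratio = coeff_ovp / coeff_ham
    ham_loss = (loss_ham + ratio * loss_ovp) / (1.0 + ratio)
    # outer mix: both the Hamiltonian-type loss and the eigenvalue loss always contribute
    return coeff_ham * ham_loss + (1.0 - coeff_ham) * eigloss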
if self.onsite_shift:
    batch = data.get("batch", torch.zeros(data[AtomicDataDict.POSITIONS_KEY].shape[0]))
Tensor created on CPU may cause device mismatch.
torch.zeros(...) creates a tensor on CPU by default. If data tensors reside on GPU, this may cause device mismatch errors during comparison or indexing operations.
🔎 Proposed fix
if self.onsite_shift:
- batch = data.get("batch", torch.zeros(data[AtomicDataDict.POSITIONS_KEY].shape[0]))
+ batch = data.get("batch", torch.zeros(data[AtomicDataDict.POSITIONS_KEY].shape[0], device=self.device, dtype=torch.long))
🤖 Prompt for AI Agents
In @dptb/nnops/loss.py around lines 352-353, The conditional that sets batch
when self.onsite_shift is true uses torch.zeros(...) which defaults to CPU and
can cause device/dtype mismatches; change the fallback to create the tensor on
the same device/dtype as the positions tensor used earlier (access via
data[AtomicDataDict.POSITIONS_KEY]) or use a device-aware constructor (e.g.,
torch.zeros with device=pos.device and dtype=pos.dtype or torch.zeros_like) so
the batch tensor lives on the same device and has compatible dtype before any
comparisons/indexing.
# if train_data_config.get("get_Hamiltonian") and train_data_config.get("get_eigenvalues"):
#     raise RuntimeError("The train data set should not have both get_Hamiltonian and get_eigenvalues set to True.")
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Description: Check for code that assumes mutual exclusivity of get_Hamiltonian and get_eigenvalues
# Search for conditional logic that checks both flags
echo "=== Checking for conditional logic with both flags ==="
rg -nP --type=py -C3 'get_Hamiltonian.*get_eigenvalues|get_eigenvalues.*get_Hamiltonian'
echo ""
echo "=== Searching for eig_ham and EigHamLoss references ==="
# Search for loss method registration and usage
rg -nP --type=py -C3 'eig_ham|EigHamLoss'
Repository: deepmodeling/DeePTB
Length of output: 3591
🏁 Script executed:
#!/bin/bash
# Check the full config_check.py file around the guard area
echo "=== Full context of config_check.py lines 30-50 ==="
sed -n '30,50p' dptb/utils/config_check.py
# Check if there's any other validation for loss method selection
echo ""
echo "=== Searching for loss method validation logic ==="
rg -nP --type=py -C5 'loss_options.*method|train_options.*method' dptb/utils/config_check.py
# Check argcheck.py loss method validation structure
echo ""
echo "=== Loss method validation in argcheck.py ==="
rg -nP --type=py -B5 -A15 'eig_ham.*='
Repository: deepmodeling/DeePTB
Length of output: 5056
🏁 Script executed:
#!/bin/bash
# Check if there's validation that ensures loss method matches data config
echo "=== Checking for loss method validation elsewhere ==="
rg -nP --type=py 'get_eigenvalues|get_Hamiltonian' dptb/ --max-count=20
# Check EigHamLoss implementation to see if it validates configs
echo ""
echo "=== EigHamLoss implementation ==="
sed -n '300,350p' dptb/nnops/loss.py
# Check if there's any assertion about valid loss method + data flag combinations
echo ""
echo "=== Searching for combined validation logic ==="
rg -nP --type=py -A10 'def.*loss.*method|check.*loss' dptb/utils/
Repository: deepmodeling/DeePTB
Length of output: 20108
🏁 Script executed:
#!/bin/bash
# Check trainer.py for loss method validation
echo "=== Checking trainer.py around line 40 ==="
sed -n '35,80p' dptb/nnops/trainer.py
# Check if there's any validation that ties loss method to data config
echo ""
echo "=== Searching for loss method initialization/validation ==="
rg -nP --type=py -B5 -A5 'loss_method.*=|Loss\.register' dptb/nnops/trainer.py | head -50
# Check if config_check.py has any other validations we missed
echo ""
echo "=== Full config_check.py validation logic ==="
cat dptb/utils/config_check.py | head -80
Repository: deepmodeling/DeePTB
Length of output: 6483
🏁 Script executed:
#!/bin/bash
# Check the full task assignment logic in trainer.py
echo "=== Full task assignment logic in trainer.py ==="
sed -n '38,50p' dptb/nnops/trainer.py
# Check how self.task is used in the trainer
echo ""
echo "=== How self.task is used in trainer ==="
rg -nP --type=py 'self\.task' dptb/nnops/trainer.py
# Check if there's any post-validation that would catch this gap
echo ""
echo "=== Checking for loss method to data flag validation ==="
rg -nP --type=py -C3 'eig_ham.*method|method.*eig_ham' dptb/
Repository: deepmodeling/DeePTB
Length of output: 811
Add validation to enforce eig_ham loss method when both get_Hamiltonian and get_eigenvalues are True.
Removing this guard creates a critical validation gap. The current validation in config_check.py only handles mutual-exclusivity (lines 34-38) but lacks a case for when both flags are True. This allows invalid configurations where both flags are enabled with an incompatible loss method (e.g., "hamil_abs" or "eigvals").
Additionally, the task assignment logic in trainer.py (lines 38-45) doesn't account for the combined case:
if self.train_datasets.get_Hamiltonian:
    self.task = "hamiltonians"
elif self.train_datasets.get_DM:
    self.task = "DM"
else:
    self.task = "eigenvalues"
When both flags are True, this unconditionally assigns task="hamiltonians", bypassing proper handling of the new eig_ham loss workflow.
Add an assertion in config_check.py to enforce that when both flags are True, the loss method must be "eig_ham".
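A sketch of such a guard; the function name and the way the loss method is looked up are assumptions about the surrounding config_check.py code:

def check_eig_ham_consistency(train_data_config, loss_options):
    # When both flags are requested, only the combined loss can consume both labels.
    both = bool(train_data_config.get("get_Hamiltonian")) and bool(train_data_config.get("get_eigenvalues"))
    if both and loss_options.get("method") != "eig_ham":
        raise ValueError(
            "get_Hamiltonian and get_eigenvalues are both True; "
            "the loss method must be set to 'eig_ham'."
        )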
🤖 Prompt for AI Agents
In @dptb/utils/config_check.py around lines 40-41, Add a validation in
config_check.py that if train_data_config.get("get_Hamiltonian") and
train_data_config.get("get_eigenvalues") are both True then assert or raise
unless the configured loss method equals "eig_ham" (i.e., enforce loss_method ==
"eig_ham"); also update the task-selection logic in trainer.py (the block that
checks self.train_datasets.get_Hamiltonian / get_DM / else) to detect the
combined case (both get_Hamiltonian and get_eigenvalues True) and set self.task
= "eig_ham" so the combined workflow is routed correctly.
**Changes:**
- Add new "hamil_eigvals" task type for training with both Hamiltonian and eigenvalues
- Improve task detection logic with explicit checks for all data types
- Add RuntimeError when no valid data type flags are set in train dataset
- Fix documentation typo: correct "hamiltonian penalty" to "overlap penalty" for coeff_ovp parameter

**Why:**
- Enables training on combined Hamiltonian and eigenvalues datasets
- Provides clearer error messages when dataset configuration is invalid
- Ensures at least one data type is specified for training
- Corrects misleading parameter documentation
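Against the trainer snippet quoted earlier in the review, the described task detection might look roughly like this; the get_eigenvalues attribute name and the error message are assumptions, not the merged code:

get_H = self.train_datasets.get_Hamiltonian
get_eig = self.train_datasets.get_eigenvalues
if get_H and get_eig:
    self.task = "hamil_eigvals"   # combined Hamiltonian + eigenvalue training
elif get_H:
    self.task = "hamiltonians"
elif self.train_datasets.get_DM:
    self.task = "DM"
elif get_eig:
    self.task = "eigenvalues"
else:
    raise RuntimeError("No valid data type flag (get_Hamiltonian / get_DM / get_eigenvalues) is set in the train dataset.")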
Allow both integer and float types for `coeff_ham` and `coeff_ovp` arguments in loss options. This provides more flexibility for users who may naturally pass integer values (e.g., 1 instead of 1.0) while maintaining backward compatibility with existing float inputs.

Changed argument type specification from `float` to `[int, float]` for:
- coeff_ham: coefficient of the hamiltonian penalty
- coeff_ovp: coefficient of the overlap penalty
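Assuming dargs-style Argument entries, the updated options described above would read roughly as follows; defaults and doc strings follow the review discussion, not the merged file:

from dargs import Argument

coeff_ham = Argument("coeff_ham", [int, float], optional=True, default=1.0,
                     doc="The coefficient of the hamiltonian penalty. Default: 1")
coeff_ovp = Argument("coeff_ovp", [int, float], optional=True, default=1.0,
                     doc="The coefficient of the overlap penalty. Default: 1")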
Actionable comments posted: 0
🧹 Nitpick comments (1)
dptb/utils/argcheck.py (1)
840-857: Rename eig_ham list to eig_ham_fields for clarity, and ensure the new method is documented in doc_method.
The variable name eig_ham creates unintended shadowing with the method tag "eig_ham" in the Variant choices, making the code harder to follow. The proposed rename to eig_ham_fields disambiguates the two uses. Additionally, if doc_method doesn't currently include the eig_ham method, it should be added so the method is discoverable to users.
Note: dargs automatically applies defaults for nested sub_fields (coeff_ham and coeff_ovp) when the eig_ham variant choice is selected and normalized with do_default=True, so no special handling is needed in the configuration schema.
📜 Review details
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
dptb/nnops/trainer.py
dptb/utils/argcheck.py
🧰 Additional context used
🪛 Ruff (0.14.10)
dptb/nnops/trainer.py
50-50: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (2)
dptb/utils/argcheck.py (1)
1759-1759: Non-functional change (OK).
dptb/nnops/trainer.py (1)
40-50: Unable to rewrite review comment due to repository access failure. Manual verification is required to:
- Confirm actual indentation in dptb/nnops/trainer.py at line 42
- Search downstream code for all self.task consumers and verify they handle "hamil_eigvals"
- Validate that no code path is unaware of this new task type
This PR records the code for performing basis transfer from a small basis to larger ones by finetuning on the eigenvalues.
Summary by CodeRabbit
Release Notes
New Features
Bug Fixes