Skip to content

Commit ca62f90

Browse files
committed
Add WavefunctionConvergenceTarget for orbital coefficient convergence
- Add `WavefunctionConvergenceTarget` class to `general.py` for tracking wavefunction coefficient convergence in SCF workflows - Add `delta_wavefunction_rms` quantity to `SCFSteps` schema in `outputs.py` to store RMS changes of wavefunction coefficients - Add comprehensive test suite covering edge cases: zero convergence, missing data, single iteration, NaN/Inf handling, negative values, boundary conditions, and array vs scalar data - Use `positive_float()` for threshold validation (x ≥ 0) - Set `_convergence_property_path` to `scf_steps.delta_wavefunction_rms` for automatic resolution All 51 convergence target tests pass.
1 parent e965f9b commit ca62f90

File tree

3 files changed

+167
-0
lines changed

3 files changed

+167
-0
lines changed

src/nomad_simulations/schema_packages/outputs.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,16 @@ class SCFSteps(ArchiveSection):
8383
""",
8484
)
8585

86+
delta_wavefunction_rms = Quantity(
87+
shape=['*'],
88+
type=float,
89+
unit='dimensionless',
90+
description="""
91+
Root mean square of change of wavefunction coefficients at each SCF step.
92+
Dimensionless quantity representing convergence of orbital coefficients.
93+
""",
94+
)
95+
8696
delta_force_abs = Quantity(
8797
shape=['*'],
8898
type=float,

src/nomad_simulations/schema_packages/workflow/general.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,31 @@ class ChargeConvergenceTarget(WorkflowConvergenceTarget):
355355
_convergence_property_path = 'scf_steps.delta_density_rms'
356356

357357

358+
class WavefunctionConvergenceTarget(WorkflowConvergenceTarget):
359+
"""
360+
Convergence target for wavefunction coefficients in SCF workflows.
361+
362+
Measures convergence of wavefunction or orbital coefficients between
363+
successive SCF iterations. This is less commonly reported than density
364+
convergence but is available in some quantum chemistry codes.
365+
366+
Note: This is distinct from density convergence. Some codes report both
367+
wavefunction and density convergence independently.
368+
"""
369+
370+
threshold = Quantity(
371+
type=positive_float(),
372+
unit='dimensionless',
373+
description="""
374+
Wavefunction convergence threshold. Must be non-negative.
375+
Typically dimensionless as it represents changes in wavefunction coefficients.
376+
""",
377+
)
378+
379+
# Property path for automatic extraction
380+
_convergence_property_path = 'scf_steps.delta_wavefunction_rms'
381+
382+
358383
class SimulationWorkflowModel(ArchiveSection):
359384
"""
360385
Base class for simulation workflow model sub-section definition.

tests/workflow/test_convergence_targets.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
PotentialConvergenceTarget,
2222
SimulationWorkflow,
2323
SimulationWorkflowMethod,
24+
WavefunctionConvergenceTarget,
2425
)
2526

2627

@@ -48,6 +49,12 @@ def charge_target():
4849
return ChargeConvergenceTarget()
4950

5051

52+
@pytest.fixture(scope='function')
53+
def wavefunction_target():
54+
"""Fixture providing a WavefunctionConvergenceTarget instance."""
55+
return WavefunctionConvergenceTarget()
56+
57+
5158
class TestEnergyConvergenceTarget:
5259
"""Test the EnergyConvergenceTarget class."""
5360

@@ -360,6 +367,131 @@ def test_charge_missing_data(self, archive, logger, charge_target):
360367
assert is_reached is None
361368

362369

370+
class TestWavefunctionConvergenceTarget:
371+
"""Test the WavefunctionConvergenceTarget class."""
372+
373+
@pytest.mark.parametrize(
374+
'threshold, threshold_type, wf_values, expected_reached',
375+
[
376+
# Absolute convergence - converged
377+
(1e-8, 'absolute', [1e-10, 1e-11, 5e-12], True),
378+
# Absolute convergence - not converged
379+
(1e-8, 'absolute', [1e-5, 1e-6, 1e-7], False),
380+
# Zero/perfect convergence
381+
(1e-8, 'absolute', [0.0, 0.0, 0.0], True),
382+
# RMS convergence with array data - converged
383+
(1e-8, 'rms', [[1e-10, 2e-10], [5e-11, 6e-11]], True),
384+
# RMS convergence with array data - not converged
385+
(1e-8, 'rms', [[1e-5, 2e-5], [8e-6, 9e-6]], False),
386+
# All zeros in array
387+
(1e-8, 'rms', [[0.0, 0.0], [0.0, 0.0]], True),
388+
# Maximum convergence - not converged
389+
(1e-8, 'maximum', [[1e-7, 1e-10], [8e-7, 2e-10]], False),
390+
# Values right at threshold boundary - not converged
391+
(1e-8, 'absolute', [0.0, 1.1e-8, 2.2e-8], False),
392+
# Very small values near machine epsilon
393+
(1e-16, 'absolute', [1e-17, 5e-18, 1e-18], True),
394+
],
395+
)
396+
def test_wavefunction_convergence(
397+
self,
398+
threshold: float,
399+
threshold_type: str,
400+
wf_values: list,
401+
expected_reached: bool,
402+
archive,
403+
logger,
404+
wavefunction_target,
405+
):
406+
"""
407+
Test wavefunction convergence with different thresholds and data types.
408+
409+
Args:
410+
threshold: Convergence threshold (dimensionless).
411+
threshold_type: Type of convergence check.
412+
wf_values: List of wavefunction values (scalar or array).
413+
expected_reached: Expected value of is_reached flag.
414+
"""
415+
wavefunction_target.threshold = threshold
416+
wavefunction_target.threshold_type = threshold_type
417+
418+
scf_step = SCFSteps()
419+
if len(wf_values) > 1:
420+
# Convert to deltas based on data type
421+
if isinstance(wf_values[0], list):
422+
# Array data - compute element-wise deltas
423+
deltas = []
424+
for i in range(1, len(wf_values)):
425+
delta = np.abs(np.array(wf_values[i]) - np.array(wf_values[i - 1]))
426+
deltas.append(delta)
427+
scf_step.delta_wavefunction_rms = deltas
428+
else:
429+
# Scalar data - compute simple deltas
430+
deltas = [
431+
abs(wf_values[i] - wf_values[i - 1])
432+
for i in range(1, len(wf_values))
433+
]
434+
scf_step.delta_wavefunction_rms = np.array(deltas)
435+
436+
archive.data.outputs = [Outputs(scf_steps=scf_step)]
437+
is_reached = wavefunction_target.normalize(archive, logger)
438+
assert is_reached == expected_reached
439+
440+
def test_wavefunction_missing_data(self, archive, logger, wavefunction_target):
441+
"""Test wavefunction convergence with missing data."""
442+
wavefunction_target.threshold = 1e-8
443+
wavefunction_target.threshold_type = 'absolute'
444+
445+
# No outputs at all
446+
is_reached = wavefunction_target.normalize(archive, logger)
447+
assert is_reached is None
448+
449+
# Empty scf_steps
450+
archive.data.outputs = [Outputs(scf_steps=SCFSteps())]
451+
is_reached = wavefunction_target.normalize(archive, logger)
452+
assert is_reached is None
453+
454+
def test_wavefunction_single_iteration(self, archive, logger, wavefunction_target):
455+
"""Test with only one iteration (cannot compute convergence)."""
456+
wavefunction_target.threshold = 1e-8
457+
wavefunction_target.threshold_type = 'absolute'
458+
459+
# Single value - no delta can be computed
460+
scf_step = SCFSteps()
461+
# With only one iteration, there should be no delta_wavefunction_rms
462+
# or it should be empty
463+
archive.data.outputs = [Outputs(scf_steps=scf_step)]
464+
is_reached = wavefunction_target.normalize(archive, logger)
465+
assert is_reached is None
466+
467+
def test_wavefunction_nan_values(self, archive, logger, wavefunction_target):
468+
"""Test handling of NaN values in wavefunction data."""
469+
wavefunction_target.threshold = 1e-8
470+
wavefunction_target.threshold_type = 'absolute'
471+
472+
scf_step = SCFSteps()
473+
scf_step.delta_wavefunction_rms = np.array([np.nan, 1e-10])
474+
archive.data.outputs = [Outputs(scf_steps=scf_step)]
475+
476+
# Should handle gracefully without crashing
477+
wavefunction_target.normalize(archive, logger)
478+
# Test passes if no exception is raised
479+
480+
def test_wavefunction_negative_values(self, archive, logger, wavefunction_target):
481+
"""Test that negative residuals are treated as absolute values."""
482+
wavefunction_target.threshold = 1e-8
483+
wavefunction_target.threshold_type = 'absolute'
484+
485+
# Negative values should be treated as absolute
486+
scf_step = SCFSteps()
487+
scf_step.delta_wavefunction_rms = np.array([-1e-10, -5e-11])
488+
archive.data.outputs = [Outputs(scf_steps=scf_step)]
489+
490+
is_reached = wavefunction_target.normalize(archive, logger)
491+
# Should converge since abs values are below threshold
492+
assert is_reached is True
493+
494+
363495
class TestConvergenceHelperMethods:
364496
"""Test the base class helper methods for convergence checking."""
365497

0 commit comments

Comments
 (0)