Skip to content

Commit e965f9b

Browse files
committed
Add non-negative validation to convergence target thresholds
Change threshold type in all convergence target classes from `np.float64` to `positive_float()` to enforce schema-level validation. This ensures thresholds are non-negative (x ≥ 0), allowing zero thresholds but rejecting semantically invalid negative values. Updated classes: - `EnergyConvergenceTarget` - `ForceConvergenceTarget` - `PotentialConvergenceTarget` - `ChargeConvergenceTarget` Add test to verify zero thresholds are accepted. Remove test for negative thresholds as these are now prevented at the schema level.
1 parent c20742d commit e965f9b

File tree

2 files changed

+10
-23
lines changed

2 files changed

+10
-23
lines changed

src/nomad_simulations/schema_packages/workflow/general.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from structlog.stdlib import BoundLogger
99

1010
from nomad_simulations.schema_packages.common import SimulationTime
11+
from nomad_simulations.schema_packages.data_types import positive_float
1112
from nomad_simulations.schema_packages.model_method import ModelMethod
1213
from nomad_simulations.schema_packages.model_system import ModelSystem
1314
from nomad_simulations.schema_packages.outputs import Outputs
@@ -244,10 +245,10 @@ class EnergyConvergenceTarget(WorkflowConvergenceTarget):
244245
"""
245246

246247
threshold = Quantity(
247-
type=np.float64,
248+
type=positive_float(),
248249
unit='joule',
249250
description="""
250-
Energy convergence threshold.
251+
Energy convergence threshold. Must be non-negative.
251252
""",
252253
)
253254

@@ -262,10 +263,10 @@ class ForceConvergenceTarget(WorkflowConvergenceTarget):
262263
"""
263264

264265
threshold = Quantity(
265-
type=np.float64,
266+
type=positive_float(),
266267
unit='newton',
267268
description="""
268-
Force convergence threshold.
269+
Force convergence threshold. Must be non-negative.
269270
""",
270271
)
271272

@@ -325,10 +326,10 @@ class PotentialConvergenceTarget(WorkflowConvergenceTarget):
325326
"""
326327

327328
threshold = Quantity(
328-
type=np.float64,
329+
type=positive_float(),
329330
unit='joule',
330331
description="""
331-
Potential convergence threshold.
332+
Potential convergence threshold. Must be non-negative.
332333
""",
333334
)
334335

@@ -343,10 +344,10 @@ class ChargeConvergenceTarget(WorkflowConvergenceTarget):
343344
"""
344345

345346
threshold = Quantity(
346-
type=np.float64,
347+
type=positive_float(),
347348
unit='coulomb',
348349
description="""
349-
Charge/density convergence threshold.
350+
Charge/density convergence threshold. Must be non-negative.
350351
""",
351352
)
352353

tests/workflow/test_convergence_targets.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ class TestEdgeCases:
569569
"""Test edge cases and error handling."""
570570

571571
def test_zero_threshold(self, archive, logger, energy_target):
572-
"""Test with zero threshold (only exact zero should converge)."""
572+
"""Test that zero threshold is accepted (non-negative validation)."""
573573
energy_target.threshold = 0.0
574574
energy_target.threshold_type = 'absolute'
575575

@@ -582,20 +582,6 @@ def test_zero_threshold(self, archive, logger, energy_target):
582582
# With <= comparison, exact zero matches zero threshold
583583
assert is_reached is True
584584

585-
def test_negative_threshold(self, archive, logger, energy_target):
586-
"""Test that negative threshold is handled (should use absolute value)."""
587-
energy_target.threshold = -1e-6 # Negative threshold
588-
energy_target.threshold_type = 'absolute'
589-
590-
scf_step = SCFSteps()
591-
scf_step.delta_energies_total = np.array([5e-7]) * ureg.joule
592-
593-
archive.data.outputs = [Outputs(scf_steps=scf_step)]
594-
is_reached = energy_target.normalize(archive, logger)
595-
596-
# Should still check convergence (implementation dependent)
597-
assert is_reached is not None
598-
599585
def test_very_large_values(self, archive, logger, force_target):
600586
"""Test with very large force values."""
601587
force_target.threshold = 1e10

0 commit comments

Comments
 (0)