|
21 | 21 | PotentialConvergenceTarget, |
22 | 22 | SimulationWorkflow, |
23 | 23 | SimulationWorkflowMethod, |
| 24 | + WavefunctionConvergenceTarget, |
24 | 25 | ) |
25 | 26 |
|
26 | 27 |
|
@@ -48,6 +49,12 @@ def charge_target(): |
48 | 49 | return ChargeConvergenceTarget() |
49 | 50 |
|
50 | 51 |
|
| 52 | +@pytest.fixture(scope='function') |
| 53 | +def wavefunction_target(): |
| 54 | + """Fixture providing a WavefunctionConvergenceTarget instance.""" |
| 55 | + return WavefunctionConvergenceTarget() |
| 56 | + |
| 57 | + |
51 | 58 | class TestEnergyConvergenceTarget: |
52 | 59 | """Test the EnergyConvergenceTarget class.""" |
53 | 60 |
|
@@ -360,6 +367,131 @@ def test_charge_missing_data(self, archive, logger, charge_target): |
360 | 367 | assert is_reached is None |
361 | 368 |
|
362 | 369 |
|
| 370 | +class TestWavefunctionConvergenceTarget: |
| 371 | + """Test the WavefunctionConvergenceTarget class.""" |
| 372 | + |
| 373 | + @pytest.mark.parametrize( |
| 374 | + 'threshold, threshold_type, wf_values, expected_reached', |
| 375 | + [ |
| 376 | + # Absolute convergence - converged |
| 377 | + (1e-8, 'absolute', [1e-10, 1e-11, 5e-12], True), |
| 378 | + # Absolute convergence - not converged |
| 379 | + (1e-8, 'absolute', [1e-5, 1e-6, 1e-7], False), |
| 380 | + # Zero/perfect convergence |
| 381 | + (1e-8, 'absolute', [0.0, 0.0, 0.0], True), |
| 382 | + # RMS convergence with array data - converged |
| 383 | + (1e-8, 'rms', [[1e-10, 2e-10], [5e-11, 6e-11]], True), |
| 384 | + # RMS convergence with array data - not converged |
| 385 | + (1e-8, 'rms', [[1e-5, 2e-5], [8e-6, 9e-6]], False), |
| 386 | + # All zeros in array |
| 387 | + (1e-8, 'rms', [[0.0, 0.0], [0.0, 0.0]], True), |
| 388 | + # Maximum convergence - not converged |
| 389 | + (1e-8, 'maximum', [[1e-7, 1e-10], [8e-7, 2e-10]], False), |
| 390 | + # Values right at threshold boundary - not converged |
| 391 | + (1e-8, 'absolute', [0.0, 1.1e-8, 2.2e-8], False), |
| 392 | + # Very small values near machine epsilon |
| 393 | + (1e-16, 'absolute', [1e-17, 5e-18, 1e-18], True), |
| 394 | + ], |
| 395 | + ) |
| 396 | + def test_wavefunction_convergence( |
| 397 | + self, |
| 398 | + threshold: float, |
| 399 | + threshold_type: str, |
| 400 | + wf_values: list, |
| 401 | + expected_reached: bool, |
| 402 | + archive, |
| 403 | + logger, |
| 404 | + wavefunction_target, |
| 405 | + ): |
| 406 | + """ |
| 407 | + Test wavefunction convergence with different thresholds and data types. |
| 408 | +
|
| 409 | + Args: |
| 410 | + threshold: Convergence threshold (dimensionless). |
| 411 | + threshold_type: Type of convergence check. |
| 412 | + wf_values: List of wavefunction values (scalar or array). |
| 413 | + expected_reached: Expected value of is_reached flag. |
| 414 | + """ |
| 415 | + wavefunction_target.threshold = threshold |
| 416 | + wavefunction_target.threshold_type = threshold_type |
| 417 | + |
| 418 | + scf_step = SCFSteps() |
| 419 | + if len(wf_values) > 1: |
| 420 | + # Convert to deltas based on data type |
| 421 | + if isinstance(wf_values[0], list): |
| 422 | + # Array data - compute element-wise deltas |
| 423 | + deltas = [] |
| 424 | + for i in range(1, len(wf_values)): |
| 425 | + delta = np.abs(np.array(wf_values[i]) - np.array(wf_values[i - 1])) |
| 426 | + deltas.append(delta) |
| 427 | + scf_step.delta_wavefunction_rms = deltas |
| 428 | + else: |
| 429 | + # Scalar data - compute simple deltas |
| 430 | + deltas = [ |
| 431 | + abs(wf_values[i] - wf_values[i - 1]) |
| 432 | + for i in range(1, len(wf_values)) |
| 433 | + ] |
| 434 | + scf_step.delta_wavefunction_rms = np.array(deltas) |
| 435 | + |
| 436 | + archive.data.outputs = [Outputs(scf_steps=scf_step)] |
| 437 | + is_reached = wavefunction_target.normalize(archive, logger) |
| 438 | + assert is_reached == expected_reached |
| 439 | + |
| 440 | + def test_wavefunction_missing_data(self, archive, logger, wavefunction_target): |
| 441 | + """Test wavefunction convergence with missing data.""" |
| 442 | + wavefunction_target.threshold = 1e-8 |
| 443 | + wavefunction_target.threshold_type = 'absolute' |
| 444 | + |
| 445 | + # No outputs at all |
| 446 | + is_reached = wavefunction_target.normalize(archive, logger) |
| 447 | + assert is_reached is None |
| 448 | + |
| 449 | + # Empty scf_steps |
| 450 | + archive.data.outputs = [Outputs(scf_steps=SCFSteps())] |
| 451 | + is_reached = wavefunction_target.normalize(archive, logger) |
| 452 | + assert is_reached is None |
| 453 | + |
| 454 | + def test_wavefunction_single_iteration(self, archive, logger, wavefunction_target): |
| 455 | + """Test with only one iteration (cannot compute convergence).""" |
| 456 | + wavefunction_target.threshold = 1e-8 |
| 457 | + wavefunction_target.threshold_type = 'absolute' |
| 458 | + |
| 459 | + # Single value - no delta can be computed |
| 460 | + scf_step = SCFSteps() |
| 461 | + # With only one iteration, there should be no delta_wavefunction_rms |
| 462 | + # or it should be empty |
| 463 | + archive.data.outputs = [Outputs(scf_steps=scf_step)] |
| 464 | + is_reached = wavefunction_target.normalize(archive, logger) |
| 465 | + assert is_reached is None |
| 466 | + |
| 467 | + def test_wavefunction_nan_values(self, archive, logger, wavefunction_target): |
| 468 | + """Test handling of NaN values in wavefunction data.""" |
| 469 | + wavefunction_target.threshold = 1e-8 |
| 470 | + wavefunction_target.threshold_type = 'absolute' |
| 471 | + |
| 472 | + scf_step = SCFSteps() |
| 473 | + scf_step.delta_wavefunction_rms = np.array([np.nan, 1e-10]) |
| 474 | + archive.data.outputs = [Outputs(scf_steps=scf_step)] |
| 475 | + |
| 476 | + # Should handle gracefully without crashing |
| 477 | + wavefunction_target.normalize(archive, logger) |
| 478 | + # Test passes if no exception is raised |
| 479 | + |
| 480 | + def test_wavefunction_negative_values(self, archive, logger, wavefunction_target): |
| 481 | + """Test that negative residuals are treated as absolute values.""" |
| 482 | + wavefunction_target.threshold = 1e-8 |
| 483 | + wavefunction_target.threshold_type = 'absolute' |
| 484 | + |
| 485 | + # Negative values should be treated as absolute |
| 486 | + scf_step = SCFSteps() |
| 487 | + scf_step.delta_wavefunction_rms = np.array([-1e-10, -5e-11]) |
| 488 | + archive.data.outputs = [Outputs(scf_steps=scf_step)] |
| 489 | + |
| 490 | + is_reached = wavefunction_target.normalize(archive, logger) |
| 491 | + # Should converge since abs values are below threshold |
| 492 | + assert is_reached is True |
| 493 | + |
| 494 | + |
363 | 495 | class TestConvergenceHelperMethods: |
364 | 496 | """Test the base class helper methods for convergence checking.""" |
365 | 497 |
|
|
0 commit comments