|
27 | 27 | import pandas as pd |
28 | 28 | import pytest |
29 | 29 | import xarray as xr |
| 30 | +from scipy import sparse |
30 | 31 | from scipy.sparse import csr_matrix |
31 | 32 |
|
32 | 33 | from climada.hazard.base import Hazard |
| 34 | +from climada.hazard.centroids.centr import Centroids |
33 | 35 | from climada.hazard.forecast import HazardForecast |
34 | 36 | from climada.hazard.test.test_base import hazard_kwargs |
35 | 37 |
|
@@ -355,6 +357,92 @@ def test_derived_select_null(self, haz_fc, haz_kwargs): |
355 | 357 | ) |
356 | 358 |
|
357 | 359 |
|
| 360 | +def test_check_sizes(haz_fc): |
| 361 | + """Test that _check_sizes validates matching lengths""" |
| 362 | + # Should pass with matching lengths |
| 363 | + haz_fc._check_sizes() |
| 364 | + |
| 365 | + # Test with mismatched member length - manipulate after creation |
| 366 | + haz_fc_bad = HazardForecast( |
| 367 | + lead_time=haz_fc.lead_time, |
| 368 | + member=haz_fc.member, |
| 369 | + event_id=haz_fc.event_id, |
| 370 | + event_name=haz_fc.event_name, |
| 371 | + date=haz_fc.date, |
| 372 | + haz_type=haz_fc.haz_type, |
| 373 | + units=haz_fc.units, |
| 374 | + centroids=haz_fc.centroids, |
| 375 | + intensity=haz_fc.intensity, |
| 376 | + fraction=haz_fc.fraction, |
| 377 | + ) |
| 378 | + # Manipulate member array directly to bypass __init__ validation |
| 379 | + haz_fc_bad.member = haz_fc.member[:-1] |
| 380 | + with pytest.raises(ValueError, match="Forecast.member"): |
| 381 | + haz_fc_bad._check_sizes() |
| 382 | + |
| 383 | + # Test with mismatched lead_time length |
| 384 | + haz_fc_bad2 = HazardForecast( |
| 385 | + lead_time=haz_fc.lead_time, |
| 386 | + member=haz_fc.member, |
| 387 | + event_id=haz_fc.event_id, |
| 388 | + event_name=haz_fc.event_name, |
| 389 | + date=haz_fc.date, |
| 390 | + haz_type=haz_fc.haz_type, |
| 391 | + units=haz_fc.units, |
| 392 | + centroids=haz_fc.centroids, |
| 393 | + intensity=haz_fc.intensity, |
| 394 | + fraction=haz_fc.fraction, |
| 395 | + ) |
| 396 | + # Manipulate lead_time array directly to bypass __init__ validation |
| 397 | + haz_fc_bad2.lead_time = haz_fc.lead_time[:-1] |
| 398 | + with pytest.raises(ValueError, match="Forecast.lead_time"): |
| 399 | + haz_fc_bad2._check_sizes() |
| 400 | + |
| 401 | + |
| 402 | +def test_set_event_attrs_from_forecast_dims(): |
| 403 | + """Test that _set_event_attrs_from_forecast_dims generates event attributes correctly""" |
| 404 | + lead_time = pd.timedelta_range("3h", periods=4, freq="2h").to_numpy() |
| 405 | + member = np.array([1, 2, 3, 4]) |
| 406 | + |
| 407 | + # Create a HazardForecast without event_name and date (they will be auto-generated) |
| 408 | + haz_fc = HazardForecast( |
| 409 | + lead_time=lead_time, |
| 410 | + member=member, |
| 411 | + haz_type="TC", |
| 412 | + units="m/s", |
| 413 | + event_id=np.array([10, 20, 30, 40]), |
| 414 | + intensity=sparse.csr_matrix(np.random.rand(4, 3)), |
| 415 | + centroids=Centroids(lat=np.array([1, 2, 3]), lon=np.array([4, 5, 6])), |
| 416 | + ) |
| 417 | + |
| 418 | + # Check that event_name was auto-generated |
| 419 | + assert len(haz_fc.event_name) == 4 |
| 420 | + assert haz_fc.event_name[0] == "lt_3h_m_1" |
| 421 | + assert haz_fc.event_name[1] == "lt_5h_m_2" |
| 422 | + assert haz_fc.event_name[2] == "lt_7h_m_3" |
| 423 | + assert haz_fc.event_name[3] == "lt_9h_m_4" |
| 424 | + |
| 425 | + # Check that date was set to zeros |
| 426 | + npt.assert_array_equal(haz_fc.date, np.zeros(4, dtype=int)) |
| 427 | + |
| 428 | + # Test that it raises error when lead_time and member have different lengths |
| 429 | + haz_fc_bad = HazardForecast( |
| 430 | + lead_time=lead_time, |
| 431 | + member=member, |
| 432 | + haz_type="TC", |
| 433 | + units="m/s", |
| 434 | + event_id=np.array([10, 20, 30, 40]), |
| 435 | + event_name=["a", "b", "c", "d"], # Provide event_name to bypass auto-generation |
| 436 | + date=np.array([1, 2, 3, 4]), |
| 437 | + intensity=sparse.csr_matrix(np.random.rand(4, 3)), |
| 438 | + centroids=Centroids(lat=np.array([1, 2, 3]), lon=np.array([4, 5, 6])), |
| 439 | + ) |
| 440 | + # Now manipulate arrays to create mismatch and call the method directly |
| 441 | + haz_fc_bad.member = member[:-1] |
| 442 | + with pytest.raises(ValueError, match="Length mismatch"): |
| 443 | + haz_fc_bad._set_event_attrs_from_forecast_dims() |
| 444 | + |
| 445 | + |
358 | 446 | def test_write_read_hazard_forecast(haz_fc, tmp_path): |
359 | 447 |
|
360 | 448 | file_name = tmp_path / "test_hazard_forecast.h5" |
|
0 commit comments