Skip to content

Commit 464e592

Browse files
committed
add crs and private method tests
1 parent c9041fe commit 464e592

File tree

2 files changed

+89
-0
lines changed

2 files changed

+89
-0
lines changed

climada/hazard/forecast.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ def from_xarray_raster(
339339
data=dset_squeezed,
340340
coordinate_vars=parent_coord_vars,
341341
intensity=intensity,
342+
crs=crs,
342343
)
343344

344345
kwargs = reader.get_hazard_kwargs() | {

climada/hazard/test/test_forecast.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@
2727
import pandas as pd
2828
import pytest
2929
import xarray as xr
30+
from scipy import sparse
3031
from scipy.sparse import csr_matrix
3132

3233
from climada.hazard.base import Hazard
34+
from climada.hazard.centroids.centr import Centroids
3335
from climada.hazard.forecast import HazardForecast
3436
from climada.hazard.test.test_base import hazard_kwargs
3537

@@ -355,6 +357,92 @@ def test_derived_select_null(self, haz_fc, haz_kwargs):
355357
)
356358

357359

360+
def test_check_sizes(haz_fc):
361+
"""Test that _check_sizes validates matching lengths"""
362+
# Should pass with matching lengths
363+
haz_fc._check_sizes()
364+
365+
# Test with mismatched member length - manipulate after creation
366+
haz_fc_bad = HazardForecast(
367+
lead_time=haz_fc.lead_time,
368+
member=haz_fc.member,
369+
event_id=haz_fc.event_id,
370+
event_name=haz_fc.event_name,
371+
date=haz_fc.date,
372+
haz_type=haz_fc.haz_type,
373+
units=haz_fc.units,
374+
centroids=haz_fc.centroids,
375+
intensity=haz_fc.intensity,
376+
fraction=haz_fc.fraction,
377+
)
378+
# Manipulate member array directly to bypass __init__ validation
379+
haz_fc_bad.member = haz_fc.member[:-1]
380+
with pytest.raises(ValueError, match="Forecast.member"):
381+
haz_fc_bad._check_sizes()
382+
383+
# Test with mismatched lead_time length
384+
haz_fc_bad2 = HazardForecast(
385+
lead_time=haz_fc.lead_time,
386+
member=haz_fc.member,
387+
event_id=haz_fc.event_id,
388+
event_name=haz_fc.event_name,
389+
date=haz_fc.date,
390+
haz_type=haz_fc.haz_type,
391+
units=haz_fc.units,
392+
centroids=haz_fc.centroids,
393+
intensity=haz_fc.intensity,
394+
fraction=haz_fc.fraction,
395+
)
396+
# Manipulate lead_time array directly to bypass __init__ validation
397+
haz_fc_bad2.lead_time = haz_fc.lead_time[:-1]
398+
with pytest.raises(ValueError, match="Forecast.lead_time"):
399+
haz_fc_bad2._check_sizes()
400+
401+
402+
def test_set_event_attrs_from_forecast_dims():
403+
"""Test that _set_event_attrs_from_forecast_dims generates event attributes correctly"""
404+
lead_time = pd.timedelta_range("3h", periods=4, freq="2h").to_numpy()
405+
member = np.array([1, 2, 3, 4])
406+
407+
# Create a HazardForecast without event_name and date (they will be auto-generated)
408+
haz_fc = HazardForecast(
409+
lead_time=lead_time,
410+
member=member,
411+
haz_type="TC",
412+
units="m/s",
413+
event_id=np.array([10, 20, 30, 40]),
414+
intensity=sparse.csr_matrix(np.random.rand(4, 3)),
415+
centroids=Centroids(lat=np.array([1, 2, 3]), lon=np.array([4, 5, 6])),
416+
)
417+
418+
# Check that event_name was auto-generated
419+
assert len(haz_fc.event_name) == 4
420+
assert haz_fc.event_name[0] == "lt_3h_m_1"
421+
assert haz_fc.event_name[1] == "lt_5h_m_2"
422+
assert haz_fc.event_name[2] == "lt_7h_m_3"
423+
assert haz_fc.event_name[3] == "lt_9h_m_4"
424+
425+
# Check that date was set to zeros
426+
npt.assert_array_equal(haz_fc.date, np.zeros(4, dtype=int))
427+
428+
# Test that it raises error when lead_time and member have different lengths
429+
haz_fc_bad = HazardForecast(
430+
lead_time=lead_time,
431+
member=member,
432+
haz_type="TC",
433+
units="m/s",
434+
event_id=np.array([10, 20, 30, 40]),
435+
event_name=["a", "b", "c", "d"], # Provide event_name to bypass auto-generation
436+
date=np.array([1, 2, 3, 4]),
437+
intensity=sparse.csr_matrix(np.random.rand(4, 3)),
438+
centroids=Centroids(lat=np.array([1, 2, 3]), lon=np.array([4, 5, 6])),
439+
)
440+
# Now manipulate arrays to create mismatch and call the method directly
441+
haz_fc_bad.member = member[:-1]
442+
with pytest.raises(ValueError, match="Length mismatch"):
443+
haz_fc_bad._set_event_attrs_from_forecast_dims()
444+
445+
358446
def test_write_read_hazard_forecast(haz_fc, tmp_path):
359447

360448
file_name = tmp_path / "test_hazard_forecast.h5"

0 commit comments

Comments
 (0)