Skip to content

Commit da0c914

Browse files
add init test
1 parent 3db5803 commit da0c914

File tree

2 files changed

+37
-10
lines changed

2 files changed

+37
-10
lines changed

climada/util/forecast.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,18 @@
2323

2424

2525
class Forecast:
26-
def __init__(self, lead_time=None, member=None, *args, **kwargs):
27-
if lead_time is None:
28-
self.lead_time = np.array([])
29-
else:
30-
self.lead_time = np.array(lead_time)
31-
32-
if member is None:
33-
self.member = np.array([])
34-
else:
35-
self.member = member
26+
def __init__(
27+
self,
28+
lead_time: np.ndarray | None = None,
29+
member: np.ndarray | None = None,
30+
*args,
31+
**kwargs,
32+
):
33+
34+
self.lead_time = (
35+
np.asarray(lead_time) if lead_time is not None else np.array([])
36+
)
37+
38+
self.member = np.asarray(member) if member is not None else np.array([])
3639

3740
super().__init__(*args, **kwargs)

climada/util/test/test_forecast.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,27 @@
1818
1919
Tests for Forecast base class.
2020
"""
21+
22+
import numpy as np
23+
import numpy.testing as npt
24+
import pytest
25+
26+
from climada.util.forecast import Forecast
27+
28+
29+
def test_forecast_init():
30+
"""Test initialization of Forecast class."""
31+
forecast = Forecast()
32+
npt.assert_array_equal(forecast.lead_time, np.array([]))
33+
npt.assert_array_equal(forecast.member, np.array([]))
34+
35+
forecast = Forecast(member=np.array([1, 2]))
36+
npt.assert_array_equal(forecast.member, np.array([1, 2]), strict=True)
37+
38+
forecast = Forecast(lead_time=np.array([1, 2]))
39+
npt.assert_array_equal(forecast.lead_time, np.array([1, 2]), strict=True)
40+
41+
forecast = Forecast(lead_time=np.array([1, 2]), member=[3, 4])
42+
npt.assert_array_equal(forecast.lead_time, np.array([1, 2]), strict=True)
43+
npt.assert_array_equal(forecast.member, np.array([3, 4]), strict=True)
44+
assert isinstance(forecast.member, np.ndarray)

0 commit comments

Comments
 (0)