Skip to content

Commit 3bf8289

Browse files
committed
Added predict_diff
1 parent bb7d797 commit 3bf8289

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

climetlab_maelstrom_yr/yr.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
limit_predictors=None,
5050
probabilistic_target=False,
5151
normalize=False,
52+
predict_diff=False,
5253
verbose=False,
5354
):
5455
"""
@@ -60,6 +61,8 @@ def __init__(
6061
pattern (str): Pattern for filenames
6162
probabilistic_target (bool): If true, include target std as the second target parameter
6263
normalize (bool): If true, normalize the data
64+
predict_diff (bool): If true, change the target to be the difference between the target
65+
and the raw air_temperature_2m forecast
6366
verbose (bool): Show debug statements if True
6467
"""
6568
if size not in ["5GB", "5TB"]:
@@ -70,8 +73,9 @@ def __init__(
7073

7174
self.size = size
7275
self.parameter = parameter
73-
self.do_normalize = normalize
7476
self.probabilistic_target = probabilistic_target
77+
self.do_normalize = normalize
78+
self.do_predict_diff = predict_diff
7579
self.verbose = verbose
7680

7781
is_url = location.find("://") >= 0
@@ -244,6 +248,10 @@ def preprocess(self, ds):
244248
else:
245249
data_vars["targets"] = (("leadtime", "y", "x", "target"), np.expand_dims(ds.variables["target_mean"], -1))
246250

251+
if self.do_predict_diff:
252+
I = np.where(coords["predictor"] == "air_temperature_2m")[0][0]
253+
data_vars["targets"][1][..., 0] -= data_vars["predictors"][1][..., I]
254+
247255
if self.do_normalize:
248256
for i, name in enumerate(coords["predictor"]):
249257
Yr.normalize(data_vars["predictors"][1][..., i], name)

tests/test_yr.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,11 @@ def test_read():
1919
)
2020
ds = cmlds.to_xarray()
2121
I = np.where(ds["predictor"] == "air_temperature_2m")[0][0]
22-
value0 = ds["predictors"][0, 0, 0, 0, I].values
23-
np.testing.assert_almost_equal(value0, 0.33352637)
22+
value = ds["predictors"][0, 0, 0, 0, I].values
23+
np.testing.assert_almost_equal(value, 0.33352637)
24+
25+
value = ds["targets"][0, 0, 0, 0, 0].values
26+
np.testing.assert_almost_equal(value, 0.3591344)
2427

2528
def test_normalize_dataset():
2629
cmlds = cml.load_dataset(
@@ -29,12 +32,17 @@ def test_normalize_dataset():
2932
dates=["2020-03-01", "2020-03-02"],
3033
# location=f"{dir}/",
3134
probabilistic_target=False,
35+
predict_diff=True,
3236
normalize=True,
3337
)
3438
ds = cmlds.to_xarray()
3539
I = np.where(ds["predictor"] == "air_temperature_2m")[0][0]
36-
value0 = ds["predictors"][0, 0, 0, 0, I].values
37-
np.testing.assert_almost_equal(value0, (0.33352637 - 5.388952960570653)/7.4335476655246735)
40+
value = ds["predictors"][0, 0, 0, 0, I].values
41+
np.testing.assert_almost_equal(value, (0.33352637 - 5.388952960570653)/7.4335476655246735)
42+
43+
# Check that the raw (un-normalized) air_temperature_2m forecast has been subtracted from the target
44+
value = ds["targets"][0, 0, 0, 0, 0].values
45+
np.testing.assert_almost_equal(value, 0.3591344 - 0.33352637)
3846

3947
def test_normalize_functions():
4048
ar = np.array([0, 1, 2], np.float32)
@@ -43,7 +51,6 @@ def test_normalize_functions():
4351
np.testing.assert_array_almost_equal(ar, [0, 1, 2])
4452

4553

46-
4754
if __name__ == "__main__":
4855
from climetlab.testing import main
4956

0 commit comments

Comments
 (0)