Skip to content

Commit ed8d813

Browse files
Backports v0.14.3 (#3077)
* Fix Rotbaum serialization and deserialization (#3068)
* Fix Rotbaum to handle short series (#3073)

---------

Co-authored-by: Anurag Pant <anuragpant@cs.ucla.edu>
1 parent 3c434d8 commit ed8d813

File tree

5 files changed

+133
-10
lines changed

5 files changed

+133
-10
lines changed

src/gluonts/ext/rotbaum/_model.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import gc
2121
from collections import defaultdict
2222

23-
from gluonts.core.component import validated
23+
from gluonts.core.component import equals, validated
2424

2525

2626
class QRF:
@@ -121,6 +121,13 @@ def _create_xgboost_model(model_params: Optional[dict] = None):
121121
}
122122
return xgboost.sklearn.XGBModel(**model_params)
123123

124+
def __eq__(self, that):
125+
"""
126+
Two QRX instances are considered equal if they have the same
127+
constructor arguments.
128+
"""
129+
return equals(self, that)
130+
124131
def fit(
125132
self,
126133
x_train: Union[pd.DataFrame, List],

src/gluonts/ext/rotbaum/_predictor.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313

1414
import concurrent.futures
1515
import logging
16+
import pickle
1617
from itertools import chain
1718
from typing import Iterator, List, Optional, Any, Dict
1819
from toolz import first
1920

2021
import numpy as np
2122
import pandas as pd
23+
from pathlib import Path
2224
from itertools import compress
2325

2426
from gluonts.core.component import validated
@@ -340,6 +342,31 @@ def predict( # type: ignore
340342
item_id=ts.get("item_id"),
341343
)
342344

345+
def serialize(self, path: Path) -> None:
346+
"""
347+
This function calls parent class serialize() in order to serialize
348+
the class name, version information and constructor arguments. It
349+
persists the tree predictor by pickling the model list that is
350+
generated when pickling the TreePredictor.
351+
"""
352+
super().serialize(path)
353+
with (path / "predictor.pkl").open("wb") as f:
354+
pickle.dump(self.model_list, f)
355+
356+
@classmethod
357+
def deserialize(cls, path: Path, **kwargs: Any) -> "TreePredictor":
358+
"""
359+
This function loads and returns the serialized model. It loads
360+
the predictor class with the serialized arguments. It then loads
361+
the trained model list by reading the pickle file.
362+
"""
363+
364+
predictor = super().deserialize(path)
365+
assert isinstance(predictor, cls)
366+
with (path / "predictor.pkl").open("rb") as f:
367+
predictor.model_list = pickle.load(f)
368+
return predictor
369+
343370
def explain(
344371
self, importance_type: str = "gain", percentage: bool = True
345372
) -> ExplanationResult:

src/gluonts/ext/rotbaum/_preprocess.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -452,9 +452,12 @@ def make_features(self, time_series: Dict, starting_index: int) -> List:
452452
end_index = starting_index + self.context_window_size
453453
if starting_index < 0:
454454
prefix = [None] * abs(starting_index)
455+
time_series_window = time_series["target"]
455456
else:
456457
prefix = []
457-
time_series_window = time_series["target"][starting_index:end_index]
458+
time_series_window = time_series["target"][
459+
starting_index:end_index
460+
]
458461
only_lag_features, transform_dict = self._pre_transform(
459462
time_series_window, self.subtract_mean, self.count_nans
460463
)
@@ -464,7 +467,10 @@ def make_features(self, time_series: Dict, starting_index: int) -> List:
464467
if self.use_feat_static_real
465468
else []
466469
)
467-
if self.cardinality:
470+
if (
471+
self.cardinality
472+
and time_series.get("feat_static_cat", None) is not None
473+
):
468474
feat_static_cat = (
469475
self.encode_one_hot_all(time_series["feat_static_cat"])
470476
if self.one_hot_encode
@@ -477,10 +483,10 @@ def make_features(self, time_series: Dict, starting_index: int) -> List:
477483
list(
478484
chain(
479485
*[
480-
list(ent[0]) + list(ent[1].values())
486+
prefix + list(ent[0]) + list(ent[1].values())
481487
for ent in [
482488
self._pre_transform(
483-
ts[starting_index:end_index],
489+
ts if prefix else ts[starting_index:end_index],
484490
self.subtract_mean,
485491
self.count_nans,
486492
)

test/ext/rotbaum/test_model.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
# express or implied. See the License for the specific language governing
1212
# permissions and limitations under the License.
1313

14-
14+
from pathlib import Path
1515
import pytest
16+
import tempfile
1617

17-
from gluonts.ext.rotbaum import TreeEstimator
18+
from gluonts.ext.rotbaum import TreeEstimator, TreePredictor
1819

1920

2021
@pytest.fixture()
@@ -33,5 +34,20 @@ def test_accuracy(accuracy_test, hyperparameters, quantiles):
3334
accuracy_test(TreeEstimator, hyperparameters, accuracy=0.20)
3435

3536

36-
def test_serialize(serialize_test, hyperparameters):
37-
serialize_test(TreeEstimator, hyperparameters)
37+
def test_serialize(serialize_test, hyperparameters, dsinfo):
38+
forecaster = TreeEstimator.from_hyperparameters(
39+
freq=dsinfo.freq,
40+
**{
41+
"prediction_length": dsinfo.prediction_length,
42+
"num_parallel_samples": dsinfo.num_parallel_samples,
43+
},
44+
**hyperparameters,
45+
)
46+
47+
predictor_act = forecaster.train(dsinfo.train_ds)
48+
49+
with tempfile.TemporaryDirectory() as temp_dir:
50+
predictor_act.serialize(Path(temp_dir))
51+
predictor_exp = TreePredictor.deserialize(Path(temp_dir))
52+
assert predictor_act == predictor_exp
53+
assert predictor_act.model_list == predictor_exp.model_list

test/ext/rotbaum/test_rotbaum_smoke.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
# permissions and limitations under the License.
1313

1414
import pytest
15+
import numpy as np
1516

16-
from gluonts.ext.rotbaum import TreeEstimator
17+
from gluonts.ext.rotbaum import TreeEstimator, TreePredictor
1718

1819
from gluonts.testutil.dummy_datasets import make_dummy_datasets_with_features
20+
from gluonts.dataset.common import ListDataset
1921

2022
# TODO: Add support for categorical and dynamic features.
2123

@@ -59,3 +61,68 @@ def test_rotbaum_smoke(datasets):
5961
predictor = estimator.train(dataset_train)
6062
forecasts = list(predictor.predict(dataset_test))
6163
assert len(forecasts) == len(dataset_test)
64+
65+
66+
def test_short_history_item_pred():
67+
prediction_length = 7
68+
freq = "D"
69+
70+
dataset = ListDataset(
71+
data_iter=[
72+
{
73+
"start": "2017-10-11",
74+
"item_id": "item_1",
75+
"target": np.array(
76+
[
77+
1.0,
78+
9.0,
79+
2.0,
80+
0.0,
81+
0.0,
82+
1.0,
83+
5.0,
84+
3.0,
85+
4.0,
86+
2.0,
87+
0.0,
88+
0.0,
89+
1.0,
90+
6.0,
91+
]
92+
),
93+
"feat_static_cat": np.array([0.0, 0.0], dtype=float),
94+
"past_feat_dynamic_real": np.array(
95+
[
96+
[1.0222e06 for i in range(14)],
97+
[750.0 for i in range(14)],
98+
]
99+
),
100+
},
101+
{
102+
"start": "2017-10-11",
103+
"item_id": "item_2",
104+
"target": np.array([7.0, 0.0, 0.0, 23.0, 13.0]),
105+
"feat_static_cat": np.array([0.0, 1.0], dtype=float),
106+
"past_feat_dynamic_real": np.array(
107+
[[0 for i in range(5)], [750.0 for i in range(5)]]
108+
),
109+
},
110+
],
111+
freq=freq,
112+
)
113+
114+
predictor = TreePredictor(
115+
freq=freq,
116+
prediction_length=prediction_length,
117+
quantiles=[0.1, 0.5, 0.9],
118+
max_n_datapts=50000,
119+
method="QuantileRegression",
120+
use_past_feat_dynamic_real=True,
121+
use_feat_dynamic_real=False,
122+
use_feat_dynamic_cat=False,
123+
use_feat_static_real=False,
124+
cardinality="auto",
125+
)
126+
predictor = predictor.train(dataset)
127+
forecasts = list(predictor.predict(dataset))
128+
assert forecasts[1].quantile(0.5).shape[0] == prediction_length

0 commit comments

Comments
 (0)