Skip to content

Commit 4d7967d

Browse files
export PredictionSet
1 parent 40625af commit 4d7967d

File tree

4 files changed

+10
-2
lines changed

4 files changed

+10
-2
lines changed

examples/mts_timeseries_xreg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def extract_month_year(df, date_column='date'):
8585

8686
# Fit model
8787
model = ns.MTS(RidgeCV(alphas=10**np.linspace(-3, 3, 100)),
88-
replications=100,
88+
replications=5,
8989
lags=25,
9090
type_pi="scp2-kde",
9191
kernel='gaussian',

nnetsauce/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .neuralnet.neuralnetclassification import NeuralNetClassifier
2727
from .optimizers.optimizer import Optimizer
2828
from .predictioninterval import PredictionInterval
29+
from .predictionset import PredictionSet
2930
from .quantile.quantileregression import QuantileRegressor
3031
from .quantile.quantileclassification import QuantileClassifier
3132
from .randombag.randomBagClassifier import RandomBagClassifier
@@ -72,6 +73,7 @@
7273
"NeuralNetRegressor",
7374
"NeuralNetClassifier",
7475
"PredictionInterval",
76+
"PredictionSet",
7577
"SimpleMultitaskClassifier",
7678
"Optimizer",
7779
"QuantileRegressor",

nnetsauce/mts/mts.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,6 +1402,8 @@ def predict(self, h=5, level=95, quantiles=None, **kwargs):
14021402

14031403
res = DescribeResult(self.mean_, self.lower_, self.upper_)
14041404

1405+
print("\n res", res)
1406+
14051407
if self.xreg_ is not None:
14061408
if len(self.xreg_.shape) > 1:
14071409
res2 = mx.tuple_map(
@@ -1450,6 +1452,8 @@ def predict(self, h=5, level=95, quantiles=None, **kwargs):
14501452

14511453
res = DescribeResult(self.mean_, self.lower_, self.upper_)
14521454

1455+
print("\n res", res)
1456+
14531457
if self.xreg_ is not None:
14541458
if len(self.xreg_.shape) > 1:
14551459
res2 = mx.tuple_map(
@@ -1478,6 +1482,8 @@ def predict(self, h=5, level=95, quantiles=None, **kwargs):
14781482

14791483
res = DescribeResult(self.mean_)
14801484

1485+
print("\n res", res)
1486+
14811487
if self.xreg_ is not None:
14821488
if len(self.xreg_.shape) > 1:
14831489
res2 = mx.tuple_map(

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from codecs import open
44
from os import path
55

6-
__version__ = '0.43.0'
6+
__version__ = '0.44.0'
77

88
# get the dependencies and installs
99
here = path.abspath(path.dirname(__file__))

0 commit comments

Comments
 (0)