Skip to content

Commit 4101cec

Browse files
partial_fit in MTS
1 parent 62928cd commit 4101cec

File tree

3 files changed

+294
-16
lines changed

3 files changed

+294
-16
lines changed

nnetsauce/mts/mts.py

Lines changed: 292 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,9 @@ def fit(self, X, xreg=None, **kwargs):
623623
def partial_fit(self, X, xreg=None, **kwargs):
624624
"""Update the model with new observations X, with optional regressors xreg
625625
626+
This is essentially a copy of fit() but uses partial_fit() on the underlying models
627+
when available, otherwise falls back to fit().
628+
626629
Parameters:
627630
628631
X: {array-like}, shape = [n_samples, n_features]
@@ -640,25 +643,303 @@ def partial_fit(self, X, xreg=None, **kwargs):
640643
641644
self: object
642645
"""
646+
647+
# Check if this is the first call (no previous fit)
648+
first_fit = self.df_ is None
649+
650+
if first_fit:
651+
# First time: use regular fit
652+
return self.fit(X, xreg, **kwargs)
653+
654+
# === Copy of fit() method with partial_fit modifications ===
655+
656+
try:
657+
self.init_n_series_ = X.shape[1]
658+
except IndexError as e:
659+
self.init_n_series_ = 1
660+
661+
# Automatic lag selection if requested (same as fit)
662+
if isinstance(self.lags, str):
663+
max_lags = min(25, X.shape[0] // 4)
664+
best_ic = float("inf")
665+
best_lags = 1
666+
667+
if self.verbose:
668+
print(f"\nSelecting optimal number of lags using {self.lags}...")
669+
iterator = tqdm(range(1, max_lags + 1))
670+
else:
671+
iterator = range(1, max_lags + 1)
643672

644-
assert self.df_ is not None, "fit() must be called before partial_fit()"
673+
for lag in iterator:
674+
# Convert DataFrame to numpy array before reversing
675+
if isinstance(X, pd.DataFrame):
676+
X_values = X.values[::-1]
677+
else:
678+
X_values = X[::-1]
645679

646-
if (isinstance(X, pd.DataFrame) is False) and isinstance(
647-
X, pd.Series
648-
) is False:
649-
if len(X.shape) == 1:
650-
X = X.reshape(1, -1)
680+
# Try current lag value
681+
if self.init_n_series_ > 1:
682+
mts_input = ts.create_train_inputs(X_values, lag)
683+
else:
684+
mts_input = ts.create_train_inputs(X_values.reshape(-1, 1), lag)
651685

652-
return self.fit(X, xreg, **kwargs)
686+
# Cook training set and fit model
687+
dummy_y, scaled_Z = self.cook_training_set(
688+
y=np.ones(mts_input[0].shape[0]), X=mts_input[1]
689+
)
690+
residuals_ = []
691+
692+
for i in range(self.init_n_series_):
693+
y_mean = np.mean(mts_input[0][:, i])
694+
centered_y_i = mts_input[0][:, i] - y_mean
695+
self.obj.fit(X=scaled_Z, y=centered_y_i)
696+
residuals_.append(
697+
(centered_y_i - self.obj.predict(scaled_Z)).tolist()
698+
)
699+
700+
self.residuals_ = np.asarray(residuals_).T
701+
ic = self._compute_information_criterion(
702+
curr_lags=lag, criterion=self.lags
703+
)
704+
705+
if self.verbose:
706+
print(f"Trying lags={lag}, {self.lags}={ic:.2f}")
707+
708+
if ic < best_ic:
709+
best_ic = ic
710+
best_lags = lag
711+
712+
if self.verbose:
713+
print(f"\nSelected {best_lags} lags with {self.lags}={best_ic:.2f}")
714+
715+
self.lags = best_lags
653716

717+
# Data preprocessing (same as fit)
718+
if isinstance(X, pd.DataFrame) is False:
719+
# input data set is a numpy array
720+
if xreg is None:
721+
X = pd.DataFrame(X)
722+
self.series_names = ["series" + str(i) for i in range(X.shape[1])]
723+
else:
724+
# xreg is not None
725+
X = mo.cbind(X, xreg)
726+
self.xreg_ = xreg
727+
else: # input data set is a DataFrame with column names
728+
X_index = None
729+
if X.index is not None:
730+
X_index = X.index
731+
if xreg is None:
732+
X = copy.deepcopy(mo.convert_df_to_numeric(X))
733+
else:
734+
X = copy.deepcopy(mo.cbind(mo.convert_df_to_numeric(X), xreg))
735+
self.xreg_ = xreg
736+
if X_index is not None:
737+
X.index = X_index
738+
self.series_names = X.columns.tolist()
739+
740+
# Data concatenation (same as fit)
741+
if isinstance(X, pd.DataFrame):
742+
if self.df_ is None:
743+
self.df_ = X
744+
X = X.values
745+
else:
746+
input_dates_prev = pd.DatetimeIndex(self.df_.index.values)
747+
frequency = pd.infer_freq(input_dates_prev)
748+
self.df_ = pd.concat([self.df_, X], axis=0)
749+
self.input_dates = pd.date_range(
750+
start=input_dates_prev[0],
751+
periods=len(input_dates_prev) + X.shape[0],
752+
freq=frequency,
753+
).values.tolist()
754+
self.df_.index = self.input_dates
755+
X = self.df_.values
756+
self.df_.columns = self.series_names
654757
else:
655-
if len(X.shape) == 1:
656-
X = pd.DataFrame(
657-
X.values.reshape(1, -1), columns=self.df_.columns
758+
if self.df_ is None:
759+
self.df_ = pd.DataFrame(X, columns=self.series_names)
760+
else:
761+
self.df_ = pd.concat(
762+
[self.df_, pd.DataFrame(X, columns=self.series_names)],
763+
axis=0,
658764
)
659765

660-
return self.fit(X, xreg, **kwargs)
766+
self.input_dates = ts.compute_input_dates(self.df_)
767+
768+
try:
769+
# multivariate time series
770+
n, p = X.shape
771+
except:
772+
# univariate time series
773+
n = X.shape[0]
774+
p = 1
775+
self.n_obs_ = n
776+
777+
rep_1_n = np.repeat(1, n)
661778

779+
self.y_ = None
780+
self.X_ = None
781+
self.n_series = p
782+
# NOTE: fit_objs_ and y_means_ are deliberately not cleared in partial_fit: the existing per-series models in fit_objs_ are updated in the loop below
783+
# self.fit_objs_.clear() # REMOVED for partial_fit
784+
# self.y_means_.clear() # REMOVED for partial_fit
785+
residuals_ = []
786+
self.residuals_ = None
787+
self.residuals_sims_ = None
788+
self.kde_ = None
789+
self.sims_ = None
790+
self.scaled_Z_ = None
791+
self.centered_y_is_ = []
792+
793+
if self.init_n_series_ > 1:
794+
# multivariate time series
795+
mts_input = ts.create_train_inputs(X[::-1], self.lags)
796+
else:
797+
# univariate time series
798+
mts_input = ts.create_train_inputs(X.reshape(-1, 1)[::-1], self.lags)
799+
800+
self.y_ = mts_input[0]
801+
self.X_ = mts_input[1]
802+
803+
dummy_y, scaled_Z = self.cook_training_set(y=rep_1_n, X=self.X_)
804+
self.scaled_Z_ = scaled_Z
805+
806+
# loop over all the time series and update each per-series model in self.fit_objs_ (partial_fit when available, otherwise fit) - MODIFIED for partial_fit
807+
if self.verbose > 0:
808+
print(f"\n Partially fitting {type(self.obj).__name__} to multivariate time series... \n")
809+
810+
if self.show_progress is True:
811+
iterator = tqdm(range(self.init_n_series_))
812+
else:
813+
iterator = range(self.init_n_series_)
814+
815+
if self.type_pi in (
816+
"gaussian",
817+
"kde",
818+
"bootstrap",
819+
"block-bootstrap",
820+
) or self.type_pi.startswith("vine"):
821+
for i in iterator:
822+
y_mean = np.mean(self.y_[:, i])
823+
self.y_means_[i] = y_mean
824+
centered_y_i = self.y_[:, i] - y_mean
825+
self.centered_y_is_.append(centered_y_i)
826+
827+
# KEY CHANGE: Use partial_fit if available, otherwise fall back to fit
828+
if hasattr(self.fit_objs_[i], 'partial_fit'):
829+
try:
830+
self.fit_objs_[i].partial_fit(X=scaled_Z, y=centered_y_i)
831+
except Exception as e:
832+
if self.verbose > 0:
833+
print(f"partial_fit failed for series {i}, using fit(): {e}")
834+
self.fit_objs_[i].fit(X=scaled_Z, y=centered_y_i)
835+
else:
836+
self.fit_objs_[i].fit(X=scaled_Z, y=centered_y_i)
837+
838+
residuals_.append(
839+
(centered_y_i - self.fit_objs_[i].predict(scaled_Z)).tolist()
840+
)
841+
842+
if self.type_pi == "quantile":
843+
for i in iterator:
844+
y_mean = np.mean(self.y_[:, i])
845+
self.y_means_[i] = y_mean
846+
centered_y_i = self.y_[:, i] - y_mean
847+
self.centered_y_is_.append(centered_y_i)
848+
849+
# KEY CHANGE: Use partial_fit if available
850+
if hasattr(self.fit_objs_[i], 'partial_fit'):
851+
try:
852+
self.fit_objs_[i].partial_fit(X=scaled_Z, y=centered_y_i)
853+
except Exception as e:
854+
if self.verbose > 0:
855+
print(f"partial_fit failed for series {i}, using fit(): {e}")
856+
self.fit_objs_[i].fit(X=scaled_Z, y=centered_y_i)
857+
else:
858+
self.fit_objs_[i].fit(X=scaled_Z, y=centered_y_i)
859+
860+
if self.type_pi.startswith("scp"):
861+
# split conformal prediction
862+
for i in iterator:
863+
n_y = self.y_.shape[0]
864+
n_y_half = n_y // 2
865+
first_half_idx = range(0, n_y_half)
866+
second_half_idx = range(n_y_half, n_y)
867+
y_mean_temp = np.mean(self.y_[first_half_idx, i])
868+
centered_y_i_temp = self.y_[first_half_idx, i] - y_mean_temp
869+
870+
# KEY CHANGE: Use partial_fit if available for first half
871+
if hasattr(self.fit_objs_[i], 'partial_fit'):
872+
try:
873+
self.fit_objs_[i].partial_fit(X=scaled_Z[first_half_idx, :], y=centered_y_i_temp)
874+
except Exception as e:
875+
if self.verbose > 0:
876+
print(f"partial_fit failed for series {i} (first half), using fit(): {e}")
877+
self.fit_objs_[i].fit(X=scaled_Z[first_half_idx, :], y=centered_y_i_temp)
878+
else:
879+
self.fit_objs_[i].fit(X=scaled_Z[first_half_idx, :], y=centered_y_i_temp)
880+
881+
# these are actually calibration residuals: second-half observations minus the predictions of the model updated on the first half
882+
residuals_.append(
883+
(
884+
self.y_[second_half_idx, i]
885+
- (y_mean_temp + self.fit_objs_[i].predict(scaled_Z[second_half_idx, :]))
886+
).tolist()
887+
)
888+
889+
# fit on the second half
890+
y_mean = np.mean(self.y_[second_half_idx, i])
891+
self.y_means_[i] = y_mean
892+
centered_y_i = self.y_[second_half_idx, i] - y_mean
893+
894+
# KEY CHANGE: Use partial_fit if available for second half
895+
if hasattr(self.fit_objs_[i], 'partial_fit'):
896+
try:
897+
self.fit_objs_[i].partial_fit(X=scaled_Z[second_half_idx, :], y=centered_y_i)
898+
except Exception as e:
899+
if self.verbose > 0:
900+
print(f"partial_fit failed for series {i} (second half), using fit(): {e}")
901+
self.fit_objs_[i].fit(X=scaled_Z[second_half_idx, :], y=centered_y_i)
902+
else:
903+
self.fit_objs_[i].fit(X=scaled_Z[second_half_idx, :], y=centered_y_i)
904+
905+
# Rest of the method is identical to fit()
906+
self.residuals_ = np.asarray(residuals_).T
907+
908+
if self.type_pi == "gaussian":
909+
self.gaussian_preds_std_ = np.std(self.residuals_, axis=0)
910+
911+
if self.type_pi.startswith("scp2"):
912+
# Calculate mean and standard deviation for each column
913+
data_mean = np.mean(self.residuals_, axis=0)
914+
self.residuals_std_dev_ = np.std(self.residuals_, axis=0)
915+
# Center and scale the array using broadcasting
916+
self.residuals_ = (
917+
self.residuals_ - data_mean[np.newaxis, :]
918+
) / self.residuals_std_dev_[np.newaxis, :]
919+
920+
if self.replications != None and "kde" in self.type_pi:
921+
if self.verbose > 0:
922+
print(f"\n Simulate residuals using {self.kernel} kernel... \n")
923+
assert self.kernel in (
924+
"gaussian",
925+
"tophat",
926+
), "currently, 'kernel' must be either 'gaussian' or 'tophat'"
927+
kernel_bandwidths = {"bandwidth": np.logspace(-6, 6, 150)}
928+
grid = GridSearchCV(
929+
KernelDensity(kernel=self.kernel, **kwargs),
930+
param_grid=kernel_bandwidths,
931+
)
932+
grid.fit(self.residuals_)
933+
934+
if self.verbose > 0:
935+
print(
936+
f"\n Best parameters for {self.kernel} kernel: {grid.best_params_} \n"
937+
)
938+
939+
self.kde_ = grid.best_estimator_
940+
941+
return self
942+
662943
def _predict_quantiles(self, h, quantiles, **kwargs):
663944
"""Predict arbitrary quantiles from simulated paths."""
664945
# Ensure output dates are set

nnetsauce/ridge2/ridge2MultitaskClassifier.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -391,10 +391,7 @@ def partial_fit(self, X, y, classes=None, learning_rate=0.01, decay=0.001, **kwa
391391
392392
Returns:
393393
self: object
394-
"""
395-
396-
import numpy as np
397-
394+
"""
398395
# Input validation
399396
X = np.asarray(X)
400397
y = np.asarray(y)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from codecs import open
44
from os import path
55

6-
__version__ = '0.40.1'
6+
__version__ = '0.40.2'
77

88
# get the dependencies and installs
99
here = path.abspath(path.dirname(__file__))

0 commit comments

Comments
 (0)