Skip to content

Commit 22315be

Browse files
authored
[FEAT] HuberIQLoss (#1307)
1 parent 2de5a82 commit 22315be

File tree

10 files changed

+305
-21
lines changed

10 files changed

+305
-21
lines changed

nbs/common.base_model.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,14 +272,14 @@
272272
" raise Exception(f'{type(self).__name__} does not support static exogenous variables.')\n",
273273
"\n",
274274
" # Protections for loss functions\n",
275-
" if isinstance(self.loss, (losses.IQLoss)):\n",
275+
" if isinstance(self.loss, (losses.IQLoss, losses.HuberIQLoss)):\n",
276276
" loss_type = type(self.loss)\n",
277277
" if not isinstance(self.valid_loss, loss_type):\n",
278278
" raise Exception(f'Please set valid_loss={type(self.loss).__name__}() when training with {type(self.loss).__name__}')\n",
279279
" if isinstance(self.loss, (losses.MQLoss, losses.HuberMQLoss)):\n",
280280
" if not isinstance(self.valid_loss, (losses.MQLoss, losses.HuberMQLoss)):\n",
281281
" raise Exception(f'Please set valid_loss to MQLoss() or HuberMQLoss() when training with {type(self.loss).__name__}')\n",
282-
" if isinstance(self.valid_loss, losses.IQLoss):\n",
282+
" if isinstance(self.valid_loss, (losses.IQLoss, losses.HuberIQLoss)):\n",
283283
" valid_loss_type = type(self.valid_loss)\n",
284284
" if not isinstance(self.loss, valid_loss_type):\n",
285285
" raise Exception(f'Please set loss={type(self.valid_loss).__name__}() when validating with {type(self.valid_loss).__name__}') \n",
@@ -425,7 +425,7 @@
425425
" )\n",
426426
" \n",
427427
" def _set_quantiles(self, quantiles=None):\n",
428-
" if quantiles is None and isinstance(self.loss, losses.IQLoss):\n",
428+
" if quantiles is None and isinstance(self.loss, (losses.IQLoss, losses.HuberIQLoss)):\n",
429429
" self.loss.update_quantile(q=[0.5])\n",
430430
" elif hasattr(self.loss, 'update_quantile') and callable(self.loss.update_quantile):\n",
431431
" self.loss.update_quantile(q=quantiles)\n",

nbs/common.model_checks.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@
146146
"# Tests a model against every loss function\n",
147147
"def check_loss_functions(model_class):\n",
148148
" loss_list = [losses.MAE(), losses.MSE(), losses.RMSE(), losses.MAPE(), losses.SMAPE(), losses.MASE(seasonality=7), \n",
149-
" losses.QuantileLoss(q=0.5), losses.MQLoss(), losses.IQLoss(), losses.DistributionLoss(\"Normal\"), \n",
149+
" losses.QuantileLoss(q=0.5), losses.MQLoss(), losses.IQLoss(), losses.HuberIQLoss(), losses.DistributionLoss(\"Normal\"), \n",
150150
" losses.DistributionLoss(\"StudentT\"), losses.DistributionLoss(\"Poisson\"), losses.DistributionLoss(\"NegativeBinomial\"), \n",
151151
" losses.DistributionLoss(\"Tweedie\", rho=1.5), losses.DistributionLoss(\"ISQF\"), losses.PMM(), losses.PMM(weighted=True), \n",
152152
" losses.GMM(), losses.GMM(weighted=True), losses.NBMM(), losses.NBMM(weighted=True), losses.HuberLoss(), \n",

nbs/core.ipynb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@
8484
"\n",
8585
"from neuralforecast.common._base_model import DistributedConfig\n",
8686
"from neuralforecast.compat import SparkDataFrame\n",
87-
"from neuralforecast.losses.pytorch import IQLoss\n",
87+
"from neuralforecast.losses.pytorch import IQLoss, HuberIQLoss\n",
8888
"from neuralforecast.tsdataset import _FilesDataset, TimeSeriesDataset, LocalFilesTimeSeriesDataset\n",
8989
"from neuralforecast.models import (\n",
9090
" GRU, LSTM, RNN, TCN, DeepAR, DilatedRNN,\n",
@@ -718,7 +718,7 @@
718718
" if count_names[model_name] > 0:\n",
719719
" model_name += str(count_names[model_name])\n",
720720
"\n",
721-
" if add_level and (model.loss.outputsize_multiplier > 1 or isinstance(model.loss, IQLoss)):\n",
721+
" if add_level and (model.loss.outputsize_multiplier > 1 or isinstance(model.loss, (IQLoss, HuberIQLoss))):\n",
722722
" continue\n",
723723
"\n",
724724
" names.extend(model_name + n for n in model.loss.output_names)\n",
@@ -1052,7 +1052,7 @@
10521052
"\n",
10531053
" fcsts_list: List = []\n",
10541054
" for model in self.models:\n",
1055-
" if self._add_level and (model.loss.outputsize_multiplier > 1 or isinstance(model.loss, IQLoss)):\n",
1055+
" if self._add_level and (model.loss.outputsize_multiplier > 1 or isinstance(model.loss, (IQLoss, HuberIQLoss))):\n",
10561056
" continue\n",
10571057
"\n",
10581058
" model.fit(dataset=self.dataset,\n",
@@ -1687,7 +1687,7 @@
16871687
"\n",
16881688
" # Predict for every quantile or level if requested and the loss function supports it\n",
16891689
" # case 1: DistributionLoss and MixtureLosses\n",
1690-
" if quantiles_ is not None and not isinstance(model.loss, IQLoss) and hasattr(model.loss, 'update_quantile') and callable(model.loss.update_quantile):\n",
1690+
" if quantiles_ is not None and not isinstance(model.loss, (IQLoss, HuberIQLoss)) and hasattr(model.loss, 'update_quantile') and callable(model.loss.update_quantile):\n",
16911691
" model_fcsts = model.predict(dataset=dataset, quantiles = quantiles_, **data_kwargs)\n",
16921692
" fcsts_list.append(model_fcsts) \n",
16931693
" col_names = []\n",
@@ -1702,7 +1702,7 @@
17021702
" else:\n",
17031703
" cols.extend(col_names)\n",
17041704
" # case 2: IQLoss\n",
1705-
" elif quantiles_ is not None and isinstance(model.loss, IQLoss):\n",
1705+
" elif quantiles_ is not None and isinstance(model.loss, (IQLoss, HuberIQLoss)):\n",
17061706
" # IQLoss does not give monotonically increasing quantiles, so we apply a hack: compute all quantiles, and take the quantile over the quantiles\n",
17071707
" quantiles_iqloss = np.linspace(0.01, 0.99, 20)\n",
17081708
" fcsts_list_iqloss = []\n",

nbs/docs/capabilities/02_objectives.ipynb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
"|[**Poisson**](../../losses.pytorch.html#distributionloss) |[**HuberQLoss**](../../losses.pytorch.html#huberized-quantile-loss)|\n",
3131
"|[**Negative Binomial**](../../losses.pytorch.html#distributionloss)|[**HuberMQLoss**](../../losses.pytorch.html#huberized-mqloss) |\n",
3232
"|[**Tweedie**](../../losses.pytorch.html#distributionloss) |[**IQLoss**](../../losses.pytorch.html#iqloss) |\n",
33-
"|[**PMM**](../../losses.pytorch.html#poisson-mixture-mesh-pmm) /[**GMM**](../../losses.pytorch.html#gaussian-mixture-mesh-gmm) / [**NBMM**](../../losses.pytorch.html#negative-binomial-mixture-mesh-nbmm) | [**ISQF**](../../losses.pytorch.html#isqf) | "
33+
"|[**PMM**](../../losses.pytorch.html#poisson-mixture-mesh-pmm) | [**HuberIQLoss**](../../losses.pytorch.html#huberized-iqloss)|\n",
34+
"|[**GMM**](../../losses.pytorch.html#gaussian-mixture-mesh-gmm) | [**ISQF**](../../losses.pytorch.html#isqf) |\n",
35+
"|[**NBMM**](../../losses.pytorch.html#negative-binomial-mixture-mesh-nbmm) | |"
3436
]
3537
}
3638
],

nbs/losses.pytorch.ipynb

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1244,6 +1244,7 @@
12441244
" self._init_sampling_distribution(device)\n",
12451245
"\n",
12461246
" quantiles = self.sampling_distr.sample(sample_size)\n",
1247+
" self.q = quantiles.squeeze(-1)\n",
12471248
" self.has_sampled = True \n",
12481249
" self.has_predicted = False\n",
12491250
"\n",
@@ -4266,6 +4267,160 @@
42664267
"![](imgs_losses/hmq_loss.png)"
42674268
]
42684269
},
4270+
{
4271+
"cell_type": "markdown",
4272+
"id": "affe8b2f",
4273+
"metadata": {},
4274+
"source": [
4275+
"## Huberized IQLoss"
4276+
]
4277+
},
4278+
{
4279+
"cell_type": "code",
4280+
"execution_count": null,
4281+
"id": "31e71b0d",
4282+
"metadata": {},
4283+
"outputs": [],
4284+
"source": [
4285+
"#| export\n",
4286+
"class HuberIQLoss(HuberQLoss):\n",
4287+
" \"\"\"Implicit Huber Quantile Loss\n",
4288+
"\n",
4289+
" Computes the huberized quantile loss between `y` and `y_hat`, with the quantile `q` provided as an input to the network. \n",
4290+
" HuberIQLoss measures the deviation of a huberized quantile forecast.\n",
4291+
" By weighting the absolute deviation in a non symmetric way, the\n",
4292+
" loss pays more attention to under or over estimation.\n",
4293+
"\n",
4294+
" $$ \\mathrm{HuberQL}(\\\\mathbf{y}_{\\\\tau}, \\\\mathbf{\\hat{y}}^{(q)}_{\\\\tau}) = \n",
4295+
" (1-q)\\, L_{\\delta}(y_{\\\\tau},\\; \\hat{y}^{(q)}_{\\\\tau}) \\mathbb{1}\\{ \\hat{y}^{(q)}_{\\\\tau} \\geq y_{\\\\tau} \\} + \n",
4296+
" q\\, L_{\\delta}(y_{\\\\tau},\\; \\hat{y}^{(q)}_{\\\\tau}) \\mathbb{1}\\{ \\hat{y}^{(q)}_{\\\\tau} < y_{\\\\tau} \\} $$\n",
4297+
"\n",
4298+
" **Parameters:**<br>\n",
4299+
" `quantile_sampling`: str, default='uniform', sampling distribution used to sample the quantiles during training. Choose from ['uniform', 'beta']. <br>\n",
4300+
" `horizon_weight`: Tensor of size h, weight for each timestamp of the forecasting window. <br>\n",
4301+
" `delta`: float=1.0, Specifies the threshold at which to change between delta-scaled L1 and L2 loss.<br>\n",
4302+
"\n",
4303+
" **References:**<br>\n",
4304+
" [Gouttes, Adèle, Kashif Rasul, Mateusz Koren, Johannes Stephan, and Tofigh Naghibi, \"Probabilistic Time Series Forecasting with Implicit Quantile Networks\".](http://arxiv.org/abs/2107.03743)\n",
4305+
" [Huber Peter, J (1964). \"Robust Estimation of a Location Parameter\". Annals of Statistics](https://projecteuclid.org/journals/annals-of-mathematical-statistics/volume-35/issue-1/Robust-Estimation-of-a-Location-Parameter/10.1214/aoms/1177703732.full)<br>\n",
4306+
" [Roger Koenker and Gilbert Bassett, Jr., \"Regression Quantiles\".](https://www.jstor.org/stable/1913643)\n",
4307+
" \"\"\"\n",
4308+
" def __init__(self, cos_embedding_dim = 64, concentration0 = 1.0, concentration1 = 1.0, delta = 1.0, horizon_weight=None):\n",
4309+
" self.update_quantile()\n",
4310+
" super(HuberIQLoss, self).__init__(\n",
4311+
" q = self.q,\n",
4312+
" delta = delta,\n",
4313+
" horizon_weight=horizon_weight\n",
4314+
" )\n",
4315+
"\n",
4316+
" self.cos_embedding_dim = cos_embedding_dim\n",
4317+
" self.concentration0 = concentration0\n",
4318+
" self.concentration1 = concentration1\n",
4319+
" self.has_sampled = False\n",
4320+
" self.has_predicted = False\n",
4321+
"\n",
4322+
" self.quantile_layer = QuantileLayer(\n",
4323+
" num_output=1, cos_embedding_dim=self.cos_embedding_dim\n",
4324+
" )\n",
4325+
" self.output_layer = nn.Sequential(\n",
4326+
" nn.Linear(1, 1), nn.PReLU()\n",
4327+
" )\n",
4328+
" \n",
4329+
" def _sample_quantiles(self, sample_size, device):\n",
4330+
" if not self.has_sampled:\n",
4331+
" self._init_sampling_distribution(device)\n",
4332+
"\n",
4333+
" quantiles = self.sampling_distr.sample(sample_size)\n",
4334+
" self.q = quantiles.squeeze(-1)\n",
4335+
" self.has_sampled = True \n",
4336+
" self.has_predicted = False\n",
4337+
"\n",
4338+
" return quantiles\n",
4339+
" \n",
4340+
" def _init_sampling_distribution(self, device):\n",
4341+
" concentration0 = torch.tensor([self.concentration0],\n",
4342+
" device=device,\n",
4343+
" dtype=torch.float32)\n",
4344+
" concentration1 = torch.tensor([self.concentration1],\n",
4345+
" device=device,\n",
4346+
" dtype=torch.float32) \n",
4347+
" self.sampling_distr = Beta(concentration0 = concentration0,\n",
4348+
" concentration1 = concentration1)\n",
4349+
"\n",
4350+
" def update_quantile(self, q: List[float] = [0.5]):\n",
4351+
" self.q = q[0]\n",
4352+
" self.output_names = [f\"_ql{q[0]}\"]\n",
4353+
" self.has_predicted = True\n",
4354+
"\n",
4355+
" def domain_map(self, y_hat):\n",
4356+
" \"\"\"\n",
4357+
" Adds IQN network to output of network\n",
4358+
"\n",
4359+
" Input shapes to this function:\n",
4360+
" \n",
4361+
" Univariate: y_hat = [B, h, 1] \n",
4362+
" Multivariate: y_hat = [B, h, N]\n",
4363+
" \"\"\"\n",
4364+
" if self.eval() and self.has_predicted:\n",
4365+
" quantiles = torch.full(size=y_hat.shape, \n",
4366+
" fill_value=self.q,\n",
4367+
" device=y_hat.device,\n",
4368+
" dtype=y_hat.dtype) \n",
4369+
" quantiles = quantiles.unsqueeze(-1) \n",
4370+
" else:\n",
4371+
" quantiles = self._sample_quantiles(sample_size=y_hat.shape,\n",
4372+
" device=y_hat.device)\n",
4373+
"\n",
4374+
" # Embed the quantiles and add to y_hat\n",
4375+
" emb_taus = self.quantile_layer(quantiles)\n",
4376+
" emb_inputs = y_hat.unsqueeze(-1) * (1.0 + emb_taus)\n",
4377+
" emb_outputs = self.output_layer(emb_inputs)\n",
4378+
" \n",
4379+
" # Domain map\n",
4380+
" y_hat = emb_outputs.squeeze(-1)\n",
4381+
"\n",
4382+
" return y_hat\n"
4383+
]
4384+
},
4385+
{
4386+
"cell_type": "code",
4387+
"execution_count": null,
4388+
"id": "9ccf9024",
4389+
"metadata": {},
4390+
"outputs": [],
4391+
"source": [
4392+
"show_doc(HuberIQLoss, name='HuberIQLoss.__init__', title_level=3)"
4393+
]
4394+
},
4395+
{
4396+
"cell_type": "code",
4397+
"execution_count": null,
4398+
"id": "23a84e21",
4399+
"metadata": {},
4400+
"outputs": [],
4401+
"source": [
4402+
"show_doc(HuberIQLoss.__call__, name='HuberIQLoss.__call__', title_level=3)"
4403+
]
4404+
},
4405+
{
4406+
"cell_type": "code",
4407+
"execution_count": null,
4408+
"id": "db4a68dc",
4409+
"metadata": {},
4410+
"outputs": [],
4411+
"source": [
4412+
"# | hide\n",
4413+
"# Unit tests\n",
4414+
"# Check that default quantile is set to 0.5 at initialization\n",
4415+
"check = HuberIQLoss()\n",
4416+
"test_eq(check.q, 0.5)\n",
4417+
"\n",
4418+
"# Check that quantiles are correctly updated - prediction\n",
4419+
"check = HuberIQLoss()\n",
4420+
"check.update_quantile([0.7])\n",
4421+
"test_eq(check.q, 0.7)"
4422+
]
4423+
},
42694424
{
42704425
"attachments": {},
42714426
"cell_type": "markdown",

neuralforecast/_modidx.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,18 @@
317317
'neuralforecast/losses/pytorch.py'),
318318
'neuralforecast.losses.pytorch.GMM.update_quantile': ( 'losses.pytorch.html#gmm.update_quantile',
319319
'neuralforecast/losses/pytorch.py'),
320+
'neuralforecast.losses.pytorch.HuberIQLoss': ( 'losses.pytorch.html#huberiqloss',
321+
'neuralforecast/losses/pytorch.py'),
322+
'neuralforecast.losses.pytorch.HuberIQLoss.__init__': ( 'losses.pytorch.html#huberiqloss.__init__',
323+
'neuralforecast/losses/pytorch.py'),
324+
'neuralforecast.losses.pytorch.HuberIQLoss._init_sampling_distribution': ( 'losses.pytorch.html#huberiqloss._init_sampling_distribution',
325+
'neuralforecast/losses/pytorch.py'),
326+
'neuralforecast.losses.pytorch.HuberIQLoss._sample_quantiles': ( 'losses.pytorch.html#huberiqloss._sample_quantiles',
327+
'neuralforecast/losses/pytorch.py'),
328+
'neuralforecast.losses.pytorch.HuberIQLoss.domain_map': ( 'losses.pytorch.html#huberiqloss.domain_map',
329+
'neuralforecast/losses/pytorch.py'),
330+
'neuralforecast.losses.pytorch.HuberIQLoss.update_quantile': ( 'losses.pytorch.html#huberiqloss.update_quantile',
331+
'neuralforecast/losses/pytorch.py'),
320332
'neuralforecast.losses.pytorch.HuberLoss': ( 'losses.pytorch.html#huberloss',
321333
'neuralforecast/losses/pytorch.py'),
322334
'neuralforecast.losses.pytorch.HuberLoss.__call__': ( 'losses.pytorch.html#huberloss.__call__',

neuralforecast/common/_base_model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def __init__(
233233
)
234234

235235
# Protections for loss functions
236-
if isinstance(self.loss, (losses.IQLoss)):
236+
if isinstance(self.loss, (losses.IQLoss, losses.HuberIQLoss)):
237237
loss_type = type(self.loss)
238238
if not isinstance(self.valid_loss, loss_type):
239239
raise Exception(
@@ -244,7 +244,7 @@ def __init__(
244244
raise Exception(
245245
f"Please set valid_loss to MQLoss() or HuberMQLoss() when training with {type(self.loss).__name__}"
246246
)
247-
if isinstance(self.valid_loss, losses.IQLoss):
247+
if isinstance(self.valid_loss, (losses.IQLoss, losses.HuberIQLoss)):
248248
valid_loss_type = type(self.valid_loss)
249249
if not isinstance(self.loss, valid_loss_type):
250250
raise Exception(
@@ -415,7 +415,9 @@ def _get_temporal_exogenous_cols(self, temporal_cols):
415415
)
416416

417417
def _set_quantiles(self, quantiles=None):
418-
if quantiles is None and isinstance(self.loss, losses.IQLoss):
418+
if quantiles is None and isinstance(
419+
self.loss, (losses.IQLoss, losses.HuberIQLoss)
420+
):
419421
self.loss.update_quantile(q=[0.5])
420422
elif hasattr(self.loss, "update_quantile") and callable(
421423
self.loss.update_quantile

neuralforecast/common/_model_checks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def check_loss_functions(model_class):
131131
losses.QuantileLoss(q=0.5),
132132
losses.MQLoss(),
133133
losses.IQLoss(),
134+
losses.HuberIQLoss(),
134135
losses.DistributionLoss("Normal"),
135136
losses.DistributionLoss("StudentT"),
136137
losses.DistributionLoss("Poisson"),

neuralforecast/core.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
from .common._base_model import DistributedConfig
3030
from .compat import SparkDataFrame
31-
from .losses.pytorch import IQLoss
31+
from .losses.pytorch import IQLoss, HuberIQLoss
3232
from neuralforecast.tsdataset import (
3333
_FilesDataset,
3434
TimeSeriesDataset,
@@ -673,7 +673,8 @@ def _get_model_names(self, add_level=False) -> List[str]:
673673
model_name += str(count_names[model_name])
674674

675675
if add_level and (
676-
model.loss.outputsize_multiplier > 1 or isinstance(model.loss, IQLoss)
676+
model.loss.outputsize_multiplier > 1
677+
or isinstance(model.loss, (IQLoss, HuberIQLoss))
677678
):
678679
continue
679680

@@ -1029,7 +1030,8 @@ def _no_refit_cross_validation(
10291030
fcsts_list: List = []
10301031
for model in self.models:
10311032
if self._add_level and (
1032-
model.loss.outputsize_multiplier > 1 or isinstance(model.loss, IQLoss)
1033+
model.loss.outputsize_multiplier > 1
1034+
or isinstance(model.loss, (IQLoss, HuberIQLoss))
10331035
):
10341036
continue
10351037

@@ -1707,7 +1709,7 @@ def _generate_forecasts(
17071709
# case 1: DistributionLoss and MixtureLosses
17081710
if (
17091711
quantiles_ is not None
1710-
and not isinstance(model.loss, IQLoss)
1712+
and not isinstance(model.loss, (IQLoss, HuberIQLoss))
17111713
and hasattr(model.loss, "update_quantile")
17121714
and callable(model.loss.update_quantile)
17131715
):
@@ -1733,7 +1735,9 @@ def _generate_forecasts(
17331735
else:
17341736
cols.extend(col_names)
17351737
# case 2: IQLoss
1736-
elif quantiles_ is not None and isinstance(model.loss, IQLoss):
1738+
elif quantiles_ is not None and isinstance(
1739+
model.loss, (IQLoss, HuberIQLoss)
1740+
):
17371741
# IQLoss does not give monotonically increasing quantiles, so we apply a hack: compute all quantiles, and take the quantile over the quantiles
17381742
quantiles_iqloss = np.linspace(0.01, 0.99, 20)
17391743
fcsts_list_iqloss = []

0 commit comments

Comments
 (0)