|
1244 | 1244 | " self._init_sampling_distribution(device)\n", |
1245 | 1245 | "\n", |
1246 | 1246 | " quantiles = self.sampling_distr.sample(sample_size)\n", |
| 1247 | + " self.q = quantiles.squeeze(-1)\n", |
1247 | 1248 | " self.has_sampled = True \n", |
1248 | 1249 | " self.has_predicted = False\n", |
1249 | 1250 | "\n", |
|
4266 | 4267 | "" |
4267 | 4268 | ] |
4268 | 4269 | }, |
| 4270 | + { |
| 4271 | + "cell_type": "markdown", |
| 4272 | + "id": "affe8b2f", |
| 4273 | + "metadata": {}, |
| 4274 | + "source": [ |
| 4275 | + "## Huberized IQLoss" |
| 4276 | + ] |
| 4277 | + }, |
| 4278 | + { |
| 4279 | + "cell_type": "code", |
| 4280 | + "execution_count": null, |
| 4281 | + "id": "31e71b0d", |
| 4282 | + "metadata": {}, |
| 4283 | + "outputs": [], |
| 4284 | + "source": [ |
| 4285 | + "#| export\n", |
| 4286 | + "class HuberIQLoss(HuberQLoss):\n", |
| 4287 | + " \"\"\"Implicit Huber Quantile Loss\n", |
| 4288 | + "\n", |
| 4289 | + " Computes the huberized quantile loss between `y` and `y_hat`, with the quantile `q` provided as an input to the network. \n", |
| 4290 | + " HuberIQLoss measures the deviation of a huberized quantile forecast.\n", |
| 4291 | + " By weighting the Huber deviation asymmetrically, the\n", |
| 4292 | + " loss penalizes under- and over-estimation differently.\n", |
| 4293 | + "\n", |
| 4294 | + " $$ \\mathrm{HuberQL}(\\\\mathbf{y}_{\\\\tau}, \\\\mathbf{\\hat{y}}^{(q)}_{\\\\tau}) = \n", |
| 4295 | + " (1-q)\\, L_{\\delta}(y_{\\\\tau},\\; \\hat{y}^{(q)}_{\\\\tau}) \\mathbb{1}\\{ \\hat{y}^{(q)}_{\\\\tau} \\geq y_{\\\\tau} \\} + \n", |
| 4296 | + " q\\, L_{\\delta}(y_{\\\\tau},\\; \\hat{y}^{(q)}_{\\\\tau}) \\mathbb{1}\\{ \\hat{y}^{(q)}_{\\\\tau} < y_{\\\\tau} \\} $$\n", |
| 4297 | + "\n", |
| 4298 | + " **Parameters:**<br>\n", |
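| 4299 | + " `cos_embedding_dim`: int, default=64, cosine embedding dimension passed to the quantile layer. <br> `concentration0`: float, default=1.0, `concentration0` parameter of the Beta distribution from which quantiles are sampled. <br> `concentration1`: float, default=1.0, `concentration1` parameter of the Beta distribution from which quantiles are sampled. <br>\n", |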
| 4300 | + " `horizon_weight`: Tensor of size h, weight for each timestamp of the forecasting window. <br>\n", |
| 4301 | + " `delta`: float=1.0, Specifies the threshold at which to change between delta-scaled L1 and L2 loss.<br>\n", |
| 4302 | + "\n", |
| 4303 | + " **References:**<br>\n", |
| 4304 | + " [Gouttes, Adèle, Kashif Rasul, Mateusz Koren, Johannes Stephan, and Tofigh Naghibi, \"Probabilistic Time Series Forecasting with Implicit Quantile Networks\".](http://arxiv.org/abs/2107.03743)<br>\n", |
| 4305 | + " [Huber Peter, J (1964). \"Robust Estimation of a Location Parameter\". Annals of Statistics](https://projecteuclid.org/journals/annals-of-mathematical-statistics/volume-35/issue-1/Robust-Estimation-of-a-Location-Parameter/10.1214/aoms/1177703732.full)<br>\n", |
| 4306 | + " [Roger Koenker and Gilbert Bassett, Jr., \"Regression Quantiles\".](https://www.jstor.org/stable/1913643)\n", |
| 4307 | + " \"\"\"\n", |
| 4308 | + " def __init__(self, cos_embedding_dim = 64, concentration0 = 1.0, concentration1 = 1.0, delta = 1.0, horizon_weight=None):\n", |
| 4309 | + " self.update_quantile()\n", |
| 4310 | + " super(HuberIQLoss, self).__init__(\n", |
| 4311 | + " q = self.q,\n", |
| 4312 | + " delta = delta,\n", |
| 4313 | + " horizon_weight=horizon_weight\n", |
| 4314 | + " )\n", |
| 4315 | + "\n", |
| 4316 | + " self.cos_embedding_dim = cos_embedding_dim\n", |
| 4317 | + " self.concentration0 = concentration0\n", |
| 4318 | + " self.concentration1 = concentration1\n", |
| 4319 | + " self.has_sampled = False\n", |
| 4320 | + " self.has_predicted = False\n", |
| 4321 | + "\n", |
| 4322 | + " self.quantile_layer = QuantileLayer(\n", |
| 4323 | + " num_output=1, cos_embedding_dim=self.cos_embedding_dim\n", |
| 4324 | + " )\n", |
| 4325 | + " self.output_layer = nn.Sequential(\n", |
| 4326 | + " nn.Linear(1, 1), nn.PReLU()\n", |
| 4327 | + " )\n", |
| 4328 | + " \n", |
| 4329 | + " def _sample_quantiles(self, sample_size, device):\n", |
| 4330 | + " if not self.has_sampled:\n", |
| 4331 | + " self._init_sampling_distribution(device)\n", |
| 4332 | + "\n", |
| 4333 | + " quantiles = self.sampling_distr.sample(sample_size)\n", |
| 4334 | + " self.q = quantiles.squeeze(-1)\n", |
| 4335 | + " self.has_sampled = True \n", |
| 4336 | + " self.has_predicted = False\n", |
| 4337 | + "\n", |
| 4338 | + " return quantiles\n", |
| 4339 | + " \n", |
| 4340 | + " def _init_sampling_distribution(self, device):\n", |
| 4341 | + " concentration0 = torch.tensor([self.concentration0],\n", |
| 4342 | + " device=device,\n", |
| 4343 | + " dtype=torch.float32)\n", |
| 4344 | + " concentration1 = torch.tensor([self.concentration1],\n", |
| 4345 | + " device=device,\n", |
| 4346 | + " dtype=torch.float32) \n", |
| 4347 | + " self.sampling_distr = Beta(concentration0 = concentration0,\n", |
| 4348 | + " concentration1 = concentration1)\n", |
| 4349 | + "\n", |
| 4350 | + " def update_quantile(self, q: List[float] = [0.5]):\n", |
| 4351 | + " self.q = q[0]\n", |
| 4352 | + " self.output_names = [f\"_ql{q[0]}\"]\n", |
| 4353 | + " self.has_predicted = True\n", |
| 4354 | + "\n", |
| 4355 | + " def domain_map(self, y_hat):\n", |
| 4356 | + " \"\"\"\n", |
| 4357 | + " Embeds the quantiles with the implicit quantile network (IQN) layers and applies them to the network output\n", |
| 4358 | + "\n", |
| 4359 | + " Input shapes to this function:\n", |
| 4360 | + " \n", |
| 4361 | + " Univariate: y_hat = [B, h, 1] \n", |
| 4362 | + " Multivariate: y_hat = [B, h, N]\n", |
| 4363 | + " \"\"\"\n", |
| 4364 | + " if self.eval() and self.has_predicted:\n", |
| 4365 | + " quantiles = torch.full(size=y_hat.shape, \n", |
| 4366 | + " fill_value=self.q,\n", |
| 4367 | + " device=y_hat.device,\n", |
| 4368 | + " dtype=y_hat.dtype) \n", |
| 4369 | + " quantiles = quantiles.unsqueeze(-1) \n", |
| 4370 | + " else:\n", |
| 4371 | + " quantiles = self._sample_quantiles(sample_size=y_hat.shape,\n", |
| 4372 | + " device=y_hat.device)\n", |
| 4373 | + "\n", |
| 4374 | + " # Embed the quantiles and add to y_hat\n", |
| 4375 | + " emb_taus = self.quantile_layer(quantiles)\n", |
| 4376 | + " emb_inputs = y_hat.unsqueeze(-1) * (1.0 + emb_taus)\n", |
| 4377 | + " emb_outputs = self.output_layer(emb_inputs)\n", |
| 4378 | + " \n", |
| 4379 | + " # Domain map\n", |
| 4380 | + " y_hat = emb_outputs.squeeze(-1)\n", |
| 4381 | + "\n", |
| 4382 | + " return y_hat\n" |
| 4383 | + ] |
| 4384 | + }, |
| 4385 | + { |
| 4386 | + "cell_type": "code", |
| 4387 | + "execution_count": null, |
| 4388 | + "id": "9ccf9024", |
| 4389 | + "metadata": {}, |
| 4390 | + "outputs": [], |
| 4391 | + "source": [ |
| 4392 | + "show_doc(HuberIQLoss, name='HuberIQLoss.__init__', title_level=3)" |
| 4393 | + ] |
| 4394 | + }, |
| 4395 | + { |
| 4396 | + "cell_type": "code", |
| 4397 | + "execution_count": null, |
| 4398 | + "id": "23a84e21", |
| 4399 | + "metadata": {}, |
| 4400 | + "outputs": [], |
| 4401 | + "source": [ |
| 4402 | + "show_doc(HuberIQLoss.__call__, name='HuberIQLoss.__call__', title_level=3)" |
| 4403 | + ] |
| 4404 | + }, |
| 4405 | + { |
| 4406 | + "cell_type": "code", |
| 4407 | + "execution_count": null, |
| 4408 | + "id": "db4a68dc", |
| 4409 | + "metadata": {}, |
| 4410 | + "outputs": [], |
| 4411 | + "source": [ |
| 4412 | + "# | hide\n", |
| 4413 | + "# Unit tests\n", |
| 4414 | + "# Check that default quantile is set to 0.5 at initialization\n", |
| 4415 | + "check = HuberIQLoss()\n", |
| 4416 | + "test_eq(check.q, 0.5)\n", |
| 4417 | + "\n", |
| 4418 | + "# Check that quantiles are correctly updated - prediction\n", |
| 4419 | + "check = HuberIQLoss()\n", |
| 4420 | + "check.update_quantile([0.7])\n", |
| 4421 | + "test_eq(check.q, 0.7)" |
| 4422 | + ] |
| 4423 | + }, |
4269 | 4424 | { |
4270 | 4425 | "attachments": {}, |
4271 | 4426 | "cell_type": "markdown", |
|
0 commit comments