|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "id": "e960490f", |
| 6 | + "metadata": {}, |
| 7 | + "source": [ |
| 8 | + "# Introduction to SageMaker JumpStart - Time Series Forecasting with Chronos" |
| 9 | + ] |
| 10 | + }, |
| 11 | + { |
| 12 | + "cell_type": "markdown", |
| 13 | + "id": "c1027fd6", |
| 14 | + "metadata": {}, |
| 15 | + "source": [ |
| 16 | + "---\n", |
| 17 | + "\n", |
| 18 | + "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. \n", |
| 19 | + "\n", |
| 20 | + "\n", |
| 21 | + "\n", |
| 22 | + "---" |
| 23 | + ] |
| 24 | + }, |
| 25 | + { |
| 26 | + "cell_type": "markdown", |
| 27 | + "id": "6ecc10fd", |
| 28 | + "metadata": {}, |
| 29 | + "source": [ |
| 30 | + "In this demo notebook, we demonstrate how to use the SageMaker Python SDK to deploy a SageMaker JumpStart time series forecasting model and invoke the endpoint." |
| 31 | + ] |
| 32 | + }, |
| 33 | + { |
| 34 | + "cell_type": "markdown", |
| 35 | + "id": "2b4e4586", |
| 36 | + "metadata": {}, |
| 37 | + "source": [ |
| 38 | + "## Setup\n", |
| 39 | + "First, upgrade to the latest sagemaker SDK to ensure all available models are deployable." |
| 40 | + ] |
| 41 | + }, |
| 42 | + { |
| 43 | + "cell_type": "code", |
| 44 | + "execution_count": 1, |
| 45 | + "id": "9f1c969a", |
| 46 | + "metadata": {}, |
| 47 | + "outputs": [ |
| 48 | + { |
| 49 | + "name": "stdout", |
| 50 | + "output_type": "stream", |
| 51 | + "text": [ |
| 52 | + "Note: you may need to restart the kernel to use updated packages.\n" |
| 53 | + ] |
| 54 | + } |
| 55 | + ], |
| 56 | + "source": [ |
| 57 | + "%pip install sagemaker --upgrade --quiet" |
| 58 | + ] |
| 59 | + }, |
| 60 | + { |
| 61 | + "cell_type": "markdown", |
| 62 | + "id": "b958b25b", |
| 63 | + "metadata": {}, |
| 64 | + "source": [ |
| 65 | + "Select the desired model to deploy. The provided dropdown filters all time series forecasting models available in SageMaker JumpStart." |
| 66 | + ] |
| 67 | + }, |
| 68 | + { |
| 69 | + "cell_type": "code", |
| 70 | + "execution_count": 2, |
| 71 | + "id": "050856e9", |
| 72 | + "metadata": {}, |
| 73 | + "outputs": [ |
| 74 | + { |
| 75 | + "name": "stdout", |
| 76 | + "output_type": "stream", |
| 77 | + "text": [ |
| 78 | + "sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml\n", |
| 79 | + "sagemaker.config INFO - Not applying SDK defaults from location: /home/shchuro/.config/sagemaker/config.yaml\n" |
| 80 | + ] |
| 81 | + }, |
| 82 | + { |
| 83 | + "data": { |
| 84 | + "application/vnd.jupyter.widget-view+json": { |
| 85 | + "model_id": "728a6ed928854101b1fc0616bc6ea4b3", |
| 86 | + "version_major": 2, |
| 87 | + "version_minor": 0 |
| 88 | + }, |
| 89 | + "text/plain": [ |
| 90 | + "Dropdown(description='Select a JumpStart time series forecasting model:', index=2, layout=Layout(width='max-co…" |
| 91 | + ] |
| 92 | + }, |
| 93 | + "metadata": {}, |
| 94 | + "output_type": "display_data" |
| 95 | + } |
| 96 | + ], |
| 97 | + "source": [ |
| 98 | + "from ipywidgets import Dropdown\n", |
| 99 | + "from sagemaker.jumpstart.notebook_utils import list_jumpstart_models\n", |
| 100 | + "\n", |
| 101 | + "dropdown = Dropdown(\n", |
| 102 | + " options=list_jumpstart_models(filter=\"task == forecasting\"),\n", |
| 103 | + " value=\"autogluon-forecasting-chronos-t5-small\",\n", |
| 104 | + " description=\"Select a JumpStart time series forecasting model:\",\n", |
| 105 | + " style={\"description_width\": \"initial\"},\n", |
| 106 | + " layout={\"width\": \"max-content\"},\n", |
| 107 | + ")\n", |
| 108 | + "display(dropdown)" |
| 109 | + ] |
| 110 | + }, |
| 111 | + { |
| 112 | + "cell_type": "code", |
| 113 | + "execution_count": 3, |
| 114 | + "id": "1b257502", |
| 115 | + "metadata": {}, |
| 116 | + "outputs": [], |
| 117 | + "source": [ |
| 118 | + "model_id = dropdown.value\n", |
| 119 | + "model_version = \"*\"" |
| 120 | + ] |
| 121 | + }, |
| 122 | + { |
| 123 | + "cell_type": "markdown", |
| 124 | + "id": "6f2fc5d2", |
| 125 | + "metadata": {}, |
| 126 | + "source": [ |
| 127 | + "## Deploy model\n", |
| 128 | + "\n", |
| 129 | + "Create a `JumpStartModel` object, which initializes default model configurations conditioned on the selected instance type. JumpStart already sets a default instance type, but you can deploy the model on other instance types by passing `instance_type` to the `JumpStartModel` class." |
| 130 | + ] |
| 131 | + }, |
| 132 | + { |
| 133 | + "cell_type": "code", |
| 134 | + "execution_count": 6, |
| 135 | + "id": "a4c5bb0e", |
| 136 | + "metadata": {}, |
| 137 | + "outputs": [], |
| 138 | + "source": [ |
| 139 | + "from sagemaker.jumpstart.model import JumpStartModel\n", |
| 140 | + "\n", |
| 141 | + "model = JumpStartModel(model_id=model_id, model_version=model_version)" |
| 142 | + ] |
| 143 | + }, |
| 144 | + { |
| 145 | + "cell_type": "markdown", |
| 146 | + "id": "67eeeab7", |
| 147 | + "metadata": {}, |
| 148 | + "source": [ |
| 149 | + "You can now deploy the model using SageMaker JumpStart. The deployment might take a few minutes. " |
| 150 | + ] |
| 151 | + }, |
| 152 | + { |
| 153 | + "cell_type": "code", |
| 154 | + "execution_count": 7, |
| 155 | + "id": "3c7726a9", |
| 156 | + "metadata": {}, |
| 157 | + "outputs": [ |
| 158 | + { |
| 159 | + "name": "stdout", |
| 160 | + "output_type": "stream", |
| 161 | + "text": [ |
| 162 | + "----------!" |
| 163 | + ] |
| 164 | + } |
| 165 | + ], |
| 166 | + "source": [ |
| 167 | + "predictor = model.deploy()" |
| 168 | + ] |
| 169 | + }, |
| 170 | + { |
| 171 | + "cell_type": "markdown", |
| 172 | + "id": "a4849068", |
| 173 | + "metadata": {}, |
| 174 | + "source": [ |
| 175 | + "## Invoke the endpoint\n", |
| 176 | + "\n", |
| 177 | + "This section demonstrates how to invoke the endpoint using example payloads that are retrieved programmatically from the `JumpStartModel` object. You can replace these example payloads with your own payloads." |
| 178 | + ] |
| 179 | + }, |
| 180 | + { |
| 181 | + "cell_type": "code", |
| 182 | + "execution_count": 8, |
| 183 | + "id": "27637734", |
| 184 | + "metadata": {}, |
| 185 | + "outputs": [], |
| 186 | + "source": [ |
| 187 | + "from pprint import pformat\n", |
| 188 | + "\n", |
| 189 | + "\n", |
| 190 | + "def nested_round(data, decimals=3):\n", |
| 191 | + " \"\"\"Round numbers, including nested dicts and list.\"\"\"\n", |
| 192 | + " if isinstance(data, float):\n", |
| 193 | + " return round(data, decimals)\n", |
| 194 | + " elif isinstance(data, list):\n", |
| 195 | + " return [nested_round(item, decimals) for item in data]\n", |
| 196 | + " elif isinstance(data, dict):\n", |
| 197 | + " return {key: nested_round(value, decimals) for key, value in data.items()}\n", |
| 198 | + " else:\n", |
| 199 | + " return data\n", |
| 200 | + "\n", |
| 201 | + "\n", |
| 202 | + "def pretty_format(data):\n", |
| 203 | + " return pformat(nested_round(data), width=150, sort_dicts=False)" |
| 204 | + ] |
| 205 | + }, |
| 206 | + { |
| 207 | + "cell_type": "code", |
| 208 | + "execution_count": 15, |
| 209 | + "id": "775c2325", |
| 210 | + "metadata": {}, |
| 211 | + "outputs": [ |
| 212 | + { |
| 213 | + "name": "stdout", |
| 214 | + "output_type": "stream", |
| 215 | + "text": [ |
| 216 | + "Input:\n", |
| 217 | + " {'inputs': [{'target': [0.0, 4.0, 5.0, 1.5, -3.0, -5.0, -3.0, 1.5, 5.0, 4.0, 0.0, -4.0, -5.0, -1.5, 3.0, 5.0, 3.0, -1.5, -5.0, -4.0]}],\n", |
| 218 | + " 'parameters': {'prediction_length': 10}}\n", |
| 219 | + "\n", |
| 220 | + "Output:\n", |
| 221 | + " {'predictions': [{'mean': [-0.488, 3.101, 3.086, 0.436, -2.867, -3.924, -2.258, 0.686, 2.456, 2.089],\n", |
| 222 | + " '0.1': [-4.331, 1.351, -0.04, -3.467, -5.713, -5.051, -5.056, -3.247, -1.907, -1.898],\n", |
| 223 | + " '0.5': [0.0, 3.507, 4.012, 0.0, -3.003, -4.903, -3.015, 1.501, 3.003, 3.003],\n", |
| 224 | + " '0.9': [1.652, 4.997, 4.997, 4.11, 0.0, -2.215, 1.602, 4.997, 4.997, 4.997]}]}\n", |
| 225 | + "\n", |
| 226 | + "===============\n", |
| 227 | + "\n", |
| 228 | + "Input:\n", |
| 229 | + " {'inputs': [{'target': [1.0, 2.0, 3.0, 2.0, 0.5, 2.0, 3.0, 2.0, 1.0], 'item_id': 'product_A', 'start': '2024-01-01T01:00:00'},\n", |
| 230 | + " {'target': [5.4, 3.0, 3.0, 2.0, 1.5, 2.0, -1.0], 'item_id': 'product_B', 'start': '2024-02-02T03:00:00'}],\n", |
| 231 | + " 'parameters': {'prediction_length': 5, 'freq': '1h', 'quantile_levels': [0.05, 0.5, 0.95], 'num_samples': 30, 'batch_size': 2}}\n", |
| 232 | + "\n", |
| 233 | + "Output:\n", |
| 234 | + " {'predictions': [{'mean': [1.731, 1.498, 1.764, 1.632, 1.465],\n", |
| 235 | + " '0.05': [0.224, 0.497, 0.224, 0.0, 0.0],\n", |
| 236 | + " '0.5': [0.995, 0.995, 1.25, 1.499, 0.995],\n", |
| 237 | + " '0.95': [4.005, 2.997, 3.278, 3.999, 3.544],\n", |
| 238 | + " 'item_id': 'product_A',\n", |
| 239 | + " 'start': '2024-01-01T10:00:00'},\n", |
| 240 | + " {'mean': [0.084, 0.916, 0.384, 1.205, 1.481],\n", |
| 241 | + " '0.05': [-1.273, -0.726, -1.537, -0.358, -0.863],\n", |
| 242 | + " '0.5': [0.0, 0.872, 0.0, 1.012, 1.059],\n", |
| 243 | + " '0.95': [2.006, 3.109, 2.552, 3.0, 4.887],\n", |
| 244 | + " 'item_id': 'product_B',\n", |
| 245 | + " 'start': '2024-02-02T10:00:00'}]}\n", |
| 246 | + "\n", |
| 247 | + "===============\n", |
| 248 | + "\n" |
| 249 | + ] |
| 250 | + } |
| 251 | + ], |
| 252 | + "source": [ |
| 253 | + "for payload in model.retrieve_all_examples():\n", |
| 254 | + " response = predictor.predict(payload.body)\n", |
| 255 | + " print(\"Input:\\n\", pretty_format(payload.body), end=\"\\n\\n\")\n", |
| 256 | + " print(\"Output:\\n\", pretty_format(response))\n", |
| 257 | + " print(\"\\n===============\\n\")" |
| 258 | + ] |
| 259 | + }, |
| 260 | + { |
| 261 | + "cell_type": "markdown", |
| 262 | + "id": "335471d8", |
| 263 | + "metadata": {}, |
| 264 | + "source": [ |
| 265 | + "The payload for Chronos models must be structured as follows.\n", |
| 266 | + "* **inputs** (required): List with at most 64 time series that need to be forecasted. Each time series is represented by a dictionary with the following keys:\n", |
| 267 | + " * **target** (required): List of observed numeric time series values. \n", |
| 268 | + " - It is recommended that each time series contains at least 30 observations.\n", |
| 269 | + " - If any time series contains fewer than 5 observations, an error will be raised.\n", |
| 270 | + " * **item_id**: String that uniquely identifies each time series. \n", |
| 271 | + " - If provided, the ID must be unique for each time series.\n", |
| 272 | + " - If provided, then the endpoint response will also include the **item_id** field for each forecast.\n", |
| 273 | + " * **start**: Timestamp of the first time series observation in ISO format (`YYYY-MM-DD` or `YYYY-MM-DDThh:mm:ss`). \n", |
| 274 | + " - If **start** field is provided, then **freq** must also be provided as part of **parameters**.\n", |
| 275 | + " - If provided, then the endpoint response will also include the **start** field indicating the first timestamp of each forecast.\n", |
| 276 | + "* **parameters**: Optional parameters to configure the model.\n", |
| 277 | + " * **prediction_length**: Integer corresponding to the number of future time series values that need to be predicted. \n", |
| 278 | + " - Recommended to keep prediction_length <= 64 since larger values will result in inaccurate quantile forecasts. Values above 1000 will raise an error.\n", |
| 279 | + " * **quantile_levels**: List of floats in range (0, 1) specifying which quantiles should should be included in the probabilistic forecast. Defaults to `[0.1, 0.5, 0.9]`. \n", |
| 280 | + " * **freq**: Frequency of the time series observations in [pandas-compatible format](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases). For example, `1h` for hourly data or `2W` for bi-weekly data. \n", |
| 281 | + " - If **freq** is provided, then **start** must also be provided for each time series in **inputs**.\n", |
| 282 | + " * **num_samples**: Number of sample trajectories generated by the Chronos model during inference. Larger values may improve accuracy but increase memory consumption and slow down inference. Defaults to `20`.\n", |
| 283 | + " * **batch_size**: Number of time series processed in parallel by the model. Larger values speed up inference but may lead to out of memory errors.\n", |
| 284 | + "\n", |
| 285 | + "All keys not marked with (required) are optional.\n", |
| 286 | + "\n", |
| 287 | + "The endpoint response contains the probabilistic (quantile) forecast for each time series included in the request." |
| 288 | + ] |
| 289 | + }, |
| 290 | + { |
| 291 | + "cell_type": "markdown", |
| 292 | + "id": "aa5c2c0a", |
| 293 | + "metadata": {}, |
| 294 | + "source": [ |
| 295 | + "## Clean up the endpoint\n", |
| 296 | + "Don't forget to clean up resources when finished to avoid unnecessary charges." |
| 297 | + ] |
| 298 | + }, |
| 299 | + { |
| 300 | + "cell_type": "code", |
| 301 | + "execution_count": 16, |
| 302 | + "id": "b1a4059e", |
| 303 | + "metadata": {}, |
| 304 | + "outputs": [], |
| 305 | + "source": [ |
| 306 | + "predictor.delete_predictor()" |
| 307 | + ] |
| 308 | + }, |
| 309 | + { |
| 310 | + "cell_type": "markdown", |
| 311 | + "id": "1b1d70fc", |
| 312 | + "metadata": {}, |
| 313 | + "source": [ |
| 314 | + "## Notebook CI Test Results\n", |
| 315 | + "\n", |
| 316 | + "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.\n", |
| 317 | + "\n", |
| 318 | + "\n", |
| 319 | + "\n", |
| 320 | + "\n", |
| 321 | + "\n", |
| 322 | + "\n", |
| 323 | + "\n", |
| 324 | + "\n", |
| 325 | + "\n", |
| 326 | + "\n", |
| 327 | + "\n", |
| 328 | + "\n", |
| 329 | + "\n", |
| 330 | + "\n", |
| 331 | + "\n", |
| 332 | + "\n", |
| 333 | + "\n", |
| 334 | + "\n", |
| 335 | + "\n", |
| 336 | + "\n", |
| 337 | + "\n", |
| 338 | + "\n", |
| 339 | + "\n", |
| 340 | + "\n", |
| 341 | + "\n", |
| 342 | + "\n", |
| 343 | + "\n", |
| 344 | + "\n", |
| 345 | + "\n", |
| 346 | + "\n" |
| 347 | + ] |
| 348 | + } |
| 349 | + ], |
| 350 | + "metadata": { |
| 351 | + "instance_type": "ml.t3.medium", |
| 352 | + "kernelspec": { |
| 353 | + "display_name": "ag", |
| 354 | + "language": "python", |
| 355 | + "name": "python3" |
| 356 | + }, |
| 357 | + "language_info": { |
| 358 | + "codemirror_mode": { |
| 359 | + "name": "ipython", |
| 360 | + "version": 3 |
| 361 | + }, |
| 362 | + "file_extension": ".py", |
| 363 | + "mimetype": "text/x-python", |
| 364 | + "name": "python", |
| 365 | + "nbconvert_exporter": "python", |
| 366 | + "pygments_lexer": "ipython3", |
| 367 | + "version": "3.11.9" |
| 368 | + } |
| 369 | + }, |
| 370 | + "nbformat": 4, |
| 371 | + "nbformat_minor": 5 |
| 372 | +} |
0 commit comments