Commit 43d1e0a

Add JumpStart time series forecasting notebook example (#4750)
* Add JumpStart time series forecasting notebook example
* Fix line length
* Add pointer to the notebook
* Update index.rst
1 parent 29ebe24 commit 43d1e0a

File tree

3 files changed

+374
-0
lines changed

generative_ai/README.md

Lines changed: 1 addition & 0 deletions
@@ -26,3 +26,4 @@ These examples showcases Amazon SageMaker's capabilities in the exciting field o
2626
- [Retrieval-Augmented Generation: Question Answering using LangChain and Cohere's Generate and Embedding Models from SageMaker JumpStart](sm-jumpstart_rag_question_answering_with_cohere_and_langchain.ipynb)
2727
- [Introduction to JumpStart - Text to Image](sm-jumpstart_stable_diffusion_text_to_image.ipynb)
2828
- [Introduction to JumpStart - Text Embedding](sm-jumpstart_text_embedding.ipynb)
29+
- [Introduction to JumpStart - Time Series Forecasting](sm-jumpstart_time_series_forecasting.ipynb)
Lines changed: 372 additions & 0 deletions
@@ -0,0 +1,372 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "e960490f",
6+
"metadata": {},
7+
"source": [
8+
"# Introduction to SageMaker JumpStart - Time Series Forecasting with Chronos"
9+
]
10+
},
11+
{
12+
"cell_type": "markdown",
13+
"id": "c1027fd6",
14+
"metadata": {},
15+
"source": [
16+
"---\n",
17+
"\n",
18+
"This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. \n",
19+
"\n",
20+
"![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/generative_ai|sm-jumpstart_time_series_forecasting.ipynb)\n",
21+
"\n",
22+
"---"
23+
]
24+
},
25+
{
26+
"cell_type": "markdown",
27+
"id": "6ecc10fd",
28+
"metadata": {},
29+
"source": [
30+
"This notebook demonstrates how to use the SageMaker Python SDK to deploy a SageMaker JumpStart time series forecasting model and invoke the endpoint."
31+
]
32+
},
33+
{
34+
"cell_type": "markdown",
35+
"id": "2b4e4586",
36+
"metadata": {},
37+
"source": [
38+
"## Setup\n",
39+
"First, upgrade to the latest version of the SageMaker Python SDK to ensure all available models are deployable."
40+
]
41+
},
42+
{
43+
"cell_type": "code",
44+
"execution_count": 1,
45+
"id": "9f1c969a",
46+
"metadata": {},
47+
"outputs": [
48+
{
49+
"name": "stdout",
50+
"output_type": "stream",
51+
"text": [
52+
"Note: you may need to restart the kernel to use updated packages.\n"
53+
]
54+
}
55+
],
56+
"source": [
57+
"%pip install sagemaker --upgrade --quiet"
58+
]
59+
},
60+
{
61+
"cell_type": "markdown",
62+
"id": "b958b25b",
63+
"metadata": {},
64+
"source": [
65+
"Select the model to deploy. The dropdown below lists all time series forecasting models available in SageMaker JumpStart."
66+
]
67+
},
68+
{
69+
"cell_type": "code",
70+
"execution_count": 2,
71+
"id": "050856e9",
72+
"metadata": {},
73+
"outputs": [
74+
{
75+
"name": "stdout",
76+
"output_type": "stream",
77+
"text": [
78+
"sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml\n",
79+
"sagemaker.config INFO - Not applying SDK defaults from location: /home/shchuro/.config/sagemaker/config.yaml\n"
80+
]
81+
},
82+
{
83+
"data": {
84+
"application/vnd.jupyter.widget-view+json": {
85+
"model_id": "728a6ed928854101b1fc0616bc6ea4b3",
86+
"version_major": 2,
87+
"version_minor": 0
88+
},
89+
"text/plain": [
90+
"Dropdown(description='Select a JumpStart time series forecasting model:', index=2, layout=Layout(width='max-co…"
91+
]
92+
},
93+
"metadata": {},
94+
"output_type": "display_data"
95+
}
96+
],
97+
"source": [
98+
"from ipywidgets import Dropdown\n",
99+
"from sagemaker.jumpstart.notebook_utils import list_jumpstart_models\n",
100+
"\n",
101+
"dropdown = Dropdown(\n",
102+
"    options=list_jumpstart_models(filter=\"task == forecasting\"),\n",
103+
"    value=\"autogluon-forecasting-chronos-t5-small\",\n",
104+
"    description=\"Select a JumpStart time series forecasting model:\",\n",
105+
"    style={\"description_width\": \"initial\"},\n",
106+
"    layout={\"width\": \"max-content\"},\n",
107+
")\n",
108+
"display(dropdown)"
109+
]
110+
},
111+
{
112+
"cell_type": "code",
113+
"execution_count": 3,
114+
"id": "1b257502",
115+
"metadata": {},
116+
"outputs": [],
117+
"source": [
118+
"model_id = dropdown.value\n",
119+
"model_version = \"*\""
120+
]
121+
},
122+
{
123+
"cell_type": "markdown",
124+
"id": "6f2fc5d2",
125+
"metadata": {},
126+
"source": [
127+
"## Deploy model\n",
128+
"\n",
129+
"Create a `JumpStartModel` object, which initializes default model configurations conditioned on the selected instance type. JumpStart already sets a default instance type, but you can deploy the model on other instance types by passing `instance_type` to the `JumpStartModel` class."
130+
]
131+
},
132+
{
133+
"cell_type": "code",
134+
"execution_count": 6,
135+
"id": "a4c5bb0e",
136+
"metadata": {},
137+
"outputs": [],
138+
"source": [
139+
"from sagemaker.jumpstart.model import JumpStartModel\n",
140+
"\n",
141+
"model = JumpStartModel(model_id=model_id, model_version=model_version)"
142+
]
143+
},
144+
{
145+
"cell_type": "markdown",
146+
"id": "67eeeab7",
147+
"metadata": {},
148+
"source": [
149+
"You can now deploy the model using SageMaker JumpStart. The deployment might take a few minutes. "
150+
]
151+
},
152+
{
153+
"cell_type": "code",
154+
"execution_count": 7,
155+
"id": "3c7726a9",
156+
"metadata": {},
157+
"outputs": [
158+
{
159+
"name": "stdout",
160+
"output_type": "stream",
161+
"text": [
162+
"----------!"
163+
]
164+
}
165+
],
166+
"source": [
167+
"predictor = model.deploy()"
168+
]
169+
},
170+
{
171+
"cell_type": "markdown",
172+
"id": "a4849068",
173+
"metadata": {},
174+
"source": [
175+
"## Invoke the endpoint\n",
176+
"\n",
177+
"This section demonstrates how to invoke the endpoint using example payloads that are retrieved programmatically from the `JumpStartModel` object. You can replace these example payloads with your own payloads."
178+
]
179+
},
180+
{
181+
"cell_type": "code",
182+
"execution_count": 8,
183+
"id": "27637734",
184+
"metadata": {},
185+
"outputs": [],
186+
"source": [
187+
"from pprint import pformat\n",
188+
"\n",
189+
"\n",
190+
"def nested_round(data, decimals=3):\n",
191+
"    \"\"\"Round floats, including those nested in dicts and lists.\"\"\"\n",
192+
"    if isinstance(data, float):\n",
193+
"        return round(data, decimals)\n",
194+
"    elif isinstance(data, list):\n",
195+
"        return [nested_round(item, decimals) for item in data]\n",
196+
"    elif isinstance(data, dict):\n",
197+
"        return {key: nested_round(value, decimals) for key, value in data.items()}\n",
198+
"    else:\n",
199+
"        return data\n",
200+
"\n",
201+
"\n",
202+
"def pretty_format(data):\n",
203+
"    return pformat(nested_round(data), width=150, sort_dicts=False)"
204+
]
205+
},
206+
{
207+
"cell_type": "code",
208+
"execution_count": 15,
209+
"id": "775c2325",
210+
"metadata": {},
211+
"outputs": [
212+
{
213+
"name": "stdout",
214+
"output_type": "stream",
215+
"text": [
216+
"Input:\n",
217+
" {'inputs': [{'target': [0.0, 4.0, 5.0, 1.5, -3.0, -5.0, -3.0, 1.5, 5.0, 4.0, 0.0, -4.0, -5.0, -1.5, 3.0, 5.0, 3.0, -1.5, -5.0, -4.0]}],\n",
218+
" 'parameters': {'prediction_length': 10}}\n",
219+
"\n",
220+
"Output:\n",
221+
" {'predictions': [{'mean': [-0.488, 3.101, 3.086, 0.436, -2.867, -3.924, -2.258, 0.686, 2.456, 2.089],\n",
222+
" '0.1': [-4.331, 1.351, -0.04, -3.467, -5.713, -5.051, -5.056, -3.247, -1.907, -1.898],\n",
223+
" '0.5': [0.0, 3.507, 4.012, 0.0, -3.003, -4.903, -3.015, 1.501, 3.003, 3.003],\n",
224+
" '0.9': [1.652, 4.997, 4.997, 4.11, 0.0, -2.215, 1.602, 4.997, 4.997, 4.997]}]}\n",
225+
"\n",
226+
"===============\n",
227+
"\n",
228+
"Input:\n",
229+
" {'inputs': [{'target': [1.0, 2.0, 3.0, 2.0, 0.5, 2.0, 3.0, 2.0, 1.0], 'item_id': 'product_A', 'start': '2024-01-01T01:00:00'},\n",
230+
" {'target': [5.4, 3.0, 3.0, 2.0, 1.5, 2.0, -1.0], 'item_id': 'product_B', 'start': '2024-02-02T03:00:00'}],\n",
231+
" 'parameters': {'prediction_length': 5, 'freq': '1h', 'quantile_levels': [0.05, 0.5, 0.95], 'num_samples': 30, 'batch_size': 2}}\n",
232+
"\n",
233+
"Output:\n",
234+
" {'predictions': [{'mean': [1.731, 1.498, 1.764, 1.632, 1.465],\n",
235+
" '0.05': [0.224, 0.497, 0.224, 0.0, 0.0],\n",
236+
" '0.5': [0.995, 0.995, 1.25, 1.499, 0.995],\n",
237+
" '0.95': [4.005, 2.997, 3.278, 3.999, 3.544],\n",
238+
" 'item_id': 'product_A',\n",
239+
" 'start': '2024-01-01T10:00:00'},\n",
240+
" {'mean': [0.084, 0.916, 0.384, 1.205, 1.481],\n",
241+
" '0.05': [-1.273, -0.726, -1.537, -0.358, -0.863],\n",
242+
" '0.5': [0.0, 0.872, 0.0, 1.012, 1.059],\n",
243+
" '0.95': [2.006, 3.109, 2.552, 3.0, 4.887],\n",
244+
" 'item_id': 'product_B',\n",
245+
" 'start': '2024-02-02T10:00:00'}]}\n",
246+
"\n",
247+
"===============\n",
248+
"\n"
249+
]
250+
}
251+
],
252+
"source": [
253+
"for payload in model.retrieve_all_examples():\n",
254+
"    response = predictor.predict(payload.body)\n",
255+
"    print(\"Input:\\n\", pretty_format(payload.body), end=\"\\n\\n\")\n",
256+
"    print(\"Output:\\n\", pretty_format(response))\n",
257+
"    print(\"\\n===============\\n\")"
258+
]
259+
},
260+
{
261+
"cell_type": "markdown",
262+
"id": "335471d8",
263+
"metadata": {},
264+
"source": [
265+
"The payload for Chronos models must be structured as follows.\n",
266+
"* **inputs** (required): List with at most 64 time series that need to be forecasted. Each time series is represented by a dictionary with the following keys:\n",
267+
"  * **target** (required): List of observed numeric time series values. \n",
268+
"    - It is recommended that each time series contains at least 30 observations.\n",
269+
"    - If any time series contains fewer than 5 observations, an error will be raised.\n",
270+
"  * **item_id**: String that uniquely identifies each time series. \n",
271+
"    - If provided, the ID must be unique for each time series.\n",
272+
"    - If provided, then the endpoint response will also include the **item_id** field for each forecast.\n",
273+
"  * **start**: Timestamp of the first time series observation in ISO format (`YYYY-MM-DD` or `YYYY-MM-DDThh:mm:ss`). \n",
274+
"    - If the **start** field is provided, then **freq** must also be provided as part of **parameters**.\n",
275+
"    - If provided, then the endpoint response will also include the **start** field indicating the first timestamp of each forecast.\n",
276+
"* **parameters**: Optional parameters to configure the model.\n",
277+
"  * **prediction_length**: Integer corresponding to the number of future time series values that need to be predicted. \n",
278+
"    - We recommend keeping `prediction_length` <= 64, since larger values will result in inaccurate quantile forecasts. Values above 1000 will raise an error.\n",
279+
"  * **quantile_levels**: List of floats in range (0, 1) specifying which quantiles should be included in the probabilistic forecast. Defaults to `[0.1, 0.5, 0.9]`. \n",
280+
"  * **freq**: Frequency of the time series observations in [pandas-compatible format](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases). For example, `1h` for hourly data or `2W` for bi-weekly data. \n",
281+
"    - If **freq** is provided, then **start** must also be provided for each time series in **inputs**.\n",
282+
"  * **num_samples**: Number of sample trajectories generated by the Chronos model during inference. Larger values may improve accuracy but increase memory consumption and slow down inference. Defaults to `20`.\n",
283+
"  * **batch_size**: Number of time series processed in parallel by the model. Larger values speed up inference but may lead to out-of-memory errors.\n",
284+
"\n",
285+
"All keys not marked with (required) are optional.\n",
286+
"\n",
287+
"The endpoint response contains the probabilistic (quantile) forecast for each time series included in the request."
288+
]
289+
},
290+
{
291+
"cell_type": "markdown",
292+
"id": "aa5c2c0a",
293+
"metadata": {},
294+
"source": [
295+
"## Clean up the endpoint\n",
296+
"Don't forget to clean up resources when finished to avoid unnecessary charges."
297+
]
298+
},
299+
{
300+
"cell_type": "code",
301+
"execution_count": 16,
302+
"id": "b1a4059e",
303+
"metadata": {},
304+
"outputs": [],
305+
"source": [
306+
"predictor.delete_predictor()"
307+
]
308+
},
309+
{
310+
"cell_type": "markdown",
311+
"id": "1b1d70fc",
312+
"metadata": {},
313+
"source": [
314+
"## Notebook CI Test Results\n",
315+
"\n",
316+
"This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.\n",
317+
"\n",
318+
"![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/generative_ai|sm-jumpstart_time_series_forecasting.ipynb)\n",
319+
"\n",
320+
"![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/generative_ai|sm-jumpstart_time_series_forecasting.ipynb)\n",
321+
"\n",
322+
"![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/generative_ai|sm-jumpstart_time_series_forecasting.ipynb)\n",
323+
"\n",
324+
"![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/generative_ai|sm-jumpstart_time_series_forecasting.ipynb)\n",
325+
"\n",
326+
"![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/generative_ai|sm-jumpstart_time_series_forecasting.ipynb)\n",
327+
"\n",
328+
"![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/generative_ai|sm-jumpstart_time_series_forecasting.ipynb)\n",
329+
"\n",
330+
"![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/generative_ai|sm-jumpstart_time_series_forecasting.ipynb)\n",
331+
"\n",
332+
"![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/generative_ai|sm-jumpstart_time_series_forecasting.ipynb)\n",
333+
"\n",
334+
"![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/generative_ai|sm-jumpstart_time_series_forecasting.ipynb)\n",
335+
"\n",
336+
"![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/generative_ai|sm-jumpstart_time_series_forecasting.ipynb)\n",
337+
"\n",
338+
"![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/generative_ai|sm-jumpstart_time_series_forecasting.ipynb)\n",
339+
"\n",
340+
"![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/generative_ai|sm-jumpstart_time_series_forecasting.ipynb)\n",
341+
"\n",
342+
"![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/generative_ai|sm-jumpstart_time_series_forecasting.ipynb)\n",
343+
"\n",
344+
"![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/generative_ai|sm-jumpstart_time_series_forecasting.ipynb)\n",
345+
"\n",
346+
"![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/generative_ai|sm-jumpstart_time_series_forecasting.ipynb)\n"
347+
]
348+
}
349+
],
350+
"metadata": {
351+
"instance_type": "ml.t3.medium",
352+
"kernelspec": {
353+
"display_name": "ag",
354+
"language": "python",
355+
"name": "python3"
356+
},
357+
"language_info": {
358+
"codemirror_mode": {
359+
"name": "ipython",
360+
"version": 3
361+
},
362+
"file_extension": ".py",
363+
"mimetype": "text/x-python",
364+
"name": "python",
365+
"nbconvert_exporter": "python",
366+
"pygments_lexer": "ipython3",
367+
"version": "3.11.9"
368+
}
369+
},
370+
"nbformat": 4,
371+
"nbformat_minor": 5
372+
}
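
The payload rules described in the notebook's markdown cell can also be checked on the client before a request is sent. The following is a minimal sketch and not part of this commit: `build_chronos_payload` and the three constants are hypothetical names, and the checks only mirror the limits stated in the notebook (at most 64 series, at least 5 observations per series, unique `item_id` values, `start` and `freq` supplied together, `prediction_length` of at most 1000, quantile levels strictly between 0 and 1). The endpoint's own validation remains authoritative.

```python
from typing import Any, Dict, List, Optional

# Limits copied from the notebook's payload description (assumed to be current).
MAX_SERIES = 64
MIN_OBSERVATIONS = 5
MAX_PREDICTION_LENGTH = 1000


def build_chronos_payload(
    inputs: List[Dict[str, Any]],
    parameters: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: check a request against the documented rules."""
    parameters = dict(parameters or {})

    if len(inputs) > MAX_SERIES:
        raise ValueError(f"at most {MAX_SERIES} time series are allowed per request")

    for series in inputs:
        if "target" not in series:
            raise ValueError("each time series needs a 'target' list")
        if len(series["target"]) < MIN_OBSERVATIONS:
            raise ValueError(f"each 'target' needs at least {MIN_OBSERVATIONS} observations")

    # item_id is optional, but any IDs that are provided must be unique.
    item_ids = [series["item_id"] for series in inputs if "item_id" in series]
    if len(item_ids) != len(set(item_ids)):
        raise ValueError("'item_id' values must be unique")

    # 'start' (per series) and 'freq' (in parameters) must be given together.
    has_start = ["start" in series for series in inputs]
    if any(has_start) and "freq" not in parameters:
        raise ValueError("'freq' is required when any series has 'start'")
    if "freq" in parameters and not all(has_start):
        raise ValueError("'start' is required for every series when 'freq' is set")

    if parameters.get("prediction_length", 0) > MAX_PREDICTION_LENGTH:
        raise ValueError(f"'prediction_length' must not exceed {MAX_PREDICTION_LENGTH}")

    for level in parameters.get("quantile_levels", []):
        if not 0 < level < 1:
            raise ValueError("'quantile_levels' must lie strictly between 0 and 1")

    payload: Dict[str, Any] = {"inputs": inputs}
    if parameters:
        payload["parameters"] = parameters
    return payload


# Example based on the second payload shown in the notebook output.
payload = build_chronos_payload(
    inputs=[
        {
            "target": [1.0, 2.0, 3.0, 2.0, 0.5, 2.0, 3.0, 2.0, 1.0],
            "item_id": "product_A",
            "start": "2024-01-01T01:00:00",
        }
    ],
    parameters={"prediction_length": 5, "freq": "1h", "quantile_levels": [0.05, 0.5, 0.95]},
)
```

The returned dictionary has the same shape as the example payloads in the notebook, so it could be passed to `predictor.predict(payload)` in place of `payload.body`.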

index.rst

Lines changed: 1 addition & 0 deletions
@@ -185,6 +185,7 @@ We recommend the following notebooks as a broad introduction to the capabilities
185185
generative_ai/sm-jumpstart_rag_question_answering_with_cohere_and_langchain
186186
generative_ai/sm-jumpstart_stable_diffusion_text_to_image
187187
generative_ai/sm-jumpstart_text_embedding
188+
generative_ai/sm-jumpstart_time_series_forecasting
188189

189190
.. toctree::
190191
:maxdepth: 1
