Skip to content

Commit a1c825b

Browse files
author
Sushanth Sathish Kumar
committed
chore: Add docstrings to functions and add check for the case where instead of endpoint name, the user provides an EndpointInput object
1 parent 34a89e0 commit a1c825b

File tree

5 files changed

+173
-8
lines changed

5 files changed

+173
-8
lines changed

src/sagemaker/dashboard/dashboard_variables.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,33 @@
2020
import json
2121

2222
class DashboardVariable:
23+
"""
24+
Represents a dashboard variable used for dynamic configuration in CloudWatch Dashboards.
25+
26+
Attributes:
27+
variable_type (str): Type of dashboard variable ('property' or 'pattern').
28+
variable_property (str): Property affected by the variable, such as a JSON property or metric dimension.
29+
inputType (str): Type of input field ('input', 'select', or 'radio') for user interaction.
30+
id (str): Identifier for the variable, up to 32 characters.
31+
label (str): Label displayed for the input field (optional, defaults based on context).
32+
search (str): Metric search expression to populate input options (required for 'select' or 'radio').
33+
populateFrom (str): Dimension name used to populate input options from search results.
34+
"""
2335
def __init__(
2436
self, variable_type, variable_property, inputType, variable_id, label, search, populateFrom
2537
):
38+
"""
39+
Initializes a DashboardVariable instance.
40+
41+
Args:
42+
variable_type (str): Type of dashboard variable ('property' or 'pattern').
43+
variable_property (str): Property affected by the variable, such as a JSON property or metric dimension.
44+
inputType (str): Type of input field ('input', 'select', or 'radio') for user interaction.
45+
variable_id (str): Identifier for the variable, up to 32 characters.
46+
label (str, optional): Label displayed for the input field (default is None).
47+
search (str, optional): Metric search expression to populate input options (required for 'select' or 'radio').
48+
populateFrom (str, optional): Dimension name used to populate input options from search results.
49+
"""
2650
self.variable_type = variable_type
2751
self.variable_property = variable_property
2852
self.inputType = inputType
@@ -32,6 +56,12 @@ def __init__(
3256
self.populateFrom = populateFrom
3357

3458
def to_dict(self):
59+
"""
60+
Converts DashboardVariable instance to a dictionary representation.
61+
62+
Returns:
63+
dict: Dictionary containing variable properties suitable for JSON serialization.
64+
"""
3565
variable_properties_dict = {}
3666
if self.variable_type is not None:
3767
variable_properties_dict["type"] = self.variable_type
@@ -50,5 +80,11 @@ def to_dict(self):
5080
return variable_properties_dict
5181

5282
def to_json(self):
83+
"""
84+
Converts DashboardVariable instance to a JSON string.
85+
86+
Returns:
87+
str: JSON string representation of the variable properties.
88+
"""
5389
json.dumps(self.to_dict(), indent=4)
5490

src/sagemaker/dashboard/dashboard_widgets.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,18 @@
2020
import json
2121

2222
class DashboardWidgetProperties:
23+
"""
24+
Represents properties of a dashboard widget used for metrics in CloudWatch.
25+
26+
Attributes:
27+
view (str): Type of visualization ('timeSeries', 'singleValue', 'gauge', 'bar', 'pie', 'table').
28+
stacked (bool): Whether to display the graph as stacked lines (applies to 'timeSeries' view).
29+
metrics (list): Array of metrics configurations for the widget.
30+
region (str): Region associated with the metrics.
31+
period (int): Period in seconds for data points on the graph.
32+
title (str): Title displayed for the graph or number (optional).
33+
markdown (str): Markdown content to display within the widget (optional).
34+
"""
2335
def __init__(
2436
self,
2537
view=None,
@@ -30,6 +42,18 @@ def __init__(
3042
title=None,
3143
markdown=None,
3244
):
45+
"""
46+
Initializes DashboardWidgetProperties instance.
47+
48+
Args:
49+
view (str, optional): Type of visualization ('timeSeries', 'singleValue', 'gauge', 'bar', 'pie', 'table').
50+
stacked (bool, optional): Whether to display the graph as stacked lines (applies to 'timeSeries' view).
51+
metrics (list, optional): Array of metrics configurations for the widget.
52+
region (str, optional): Region associated with the metrics.
53+
period (int, optional): Period in seconds for data points on the graph.
54+
title (str, optional): Title displayed for the graph or number.
55+
markdown (str, optional): Markdown content to display within the widget.
56+
"""
3357
self.view = view
3458
self.stacked = stacked
3559
self.metrics = metrics
@@ -39,6 +63,12 @@ def __init__(
3963
self.markdown = markdown
4064

4165
def to_dict(self):
66+
"""
67+
Converts DashboardWidgetProperties instance to a dictionary representation.
68+
69+
Returns:
70+
dict: Dictionary containing widget properties suitable for JSON serialization.
71+
"""
4272
widget_properties_dict = {}
4373
if self.view is not None:
4474
widget_properties_dict["view"] = self.view
@@ -57,11 +87,35 @@ def to_dict(self):
5787
return widget_properties_dict
5888

5989
def to_json(self):
90+
"""
91+
Converts DashboardWidgetProperties instance to a JSON string.
92+
93+
Returns:
94+
str: JSON string representation of the widget properties.
95+
"""
6096
json.dumps(self.to_dict(), indent=4)
6197

6298

6399
class DashboardWidget:
100+
"""
101+
Represents a widget in a CloudWatch dashboard.
102+
103+
Attributes:
104+
height (int): Height of the widget.
105+
width (int): Width of the widget.
106+
type (str): Type of the widget.
107+
properties (DashboardWidgetProperties): Properties specific to the widget type.
108+
"""
64109
def __init__(self, height, width, widget_type, properties=None):
110+
"""
111+
Initializes DashboardWidget instance.
112+
113+
Args:
114+
height (int): Height of the widget.
115+
width (int): Width of the widget.
116+
widget_type (str): Type of the widget.
117+
properties (DashboardWidgetProperties, optional): Properties specific to the widget type.
118+
"""
65119
self.height = height
66120
self.width = width
67121
self.type = widget_type
@@ -72,6 +126,12 @@ def __init__(self, height, width, widget_type, properties=None):
72126
)
73127

74128
def to_dict(self):
129+
"""
130+
Converts DashboardWidget instance to a dictionary representation.
131+
132+
Returns:
133+
dict: Dictionary containing widget attributes suitable for JSON serialization.
134+
"""
75135
return {
76136
"height": self.height,
77137
"width": self.width,
@@ -80,4 +140,10 @@ def to_dict(self):
80140
}
81141

82142
def to_json(self):
143+
"""
144+
Converts DashboardWidget instance to a JSON string.
145+
146+
Returns:
147+
str: JSON string representation of the widget attributes.
148+
"""
83149
return json.dumps(self.to_dict(), indent=4)

src/sagemaker/dashboard/data_quality_dashboard.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import json
1818
from sagemaker.dashboard.dashboard_variables import DashboardVariable
1919
from sagemaker.dashboard.dashboard_widgets import DashboardWidget, DashboardWidgetProperties
20+
from sagemaker.model_monitor.model_monitoring import EndpointInput
2021

2122
class AutomaticDataQualityDashboard:
2223
DATA_QUALITY_METRICS_ENDPOINT_NAMESPACE = (
@@ -27,7 +28,11 @@ class AutomaticDataQualityDashboard:
2728
)
2829

2930
def __init__(self, endpoint_name, monitoring_schedule_name, batch_transform_input, region_name):
30-
self.endpoint = endpoint_name
31+
if type(endpoint_name) == EndpointInput:
32+
self.endpoint = endpoint_name.endpoint_name
33+
else:
34+
self.endpoint = endpoint_name
35+
3136
self.monitoring_schedule = monitoring_schedule_name
3237
self.batch_transform = batch_transform_input
3338
self.region = region_name

src/sagemaker/dashboard/model_quality_dashboard.py

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,31 @@
11
import json
22
from sagemaker.dashboard.dashboard_widgets import DashboardWidget, DashboardWidgetProperties
3-
3+
from sagemaker.model_monitor import EndpointInput
44
class AutomaticModelQualityDashboard:
5+
"""
6+
Represents a dashboard for automatic model quality metrics in Amazon SageMaker.
7+
8+
Attributes:
9+
MODEL_QUALITY_METRICS_ENDPOINT_NAMESPACE (str): Namespace for model metrics at endpoint level.
10+
MODEL_QUALITY_METRICS_BATCH_NAMESPACE (str): Namespace for model metrics at batch transform level.
11+
REGRESSION_MODEL_QUALITY_METRICS (list): List of regression model quality metrics and their graphs.
12+
BINARY_CLASSIFICATION_MODEL_QUALITY_METRICS (list): List of binary classification model quality metrics and their graphs.
13+
MULTICLASS_CLASSIFICATION_MODEL_QUALITY_METRICS (list): List of multiclass classification model quality metrics and their graphs.
14+
15+
Methods:
16+
__init__(self, endpoint_name, monitoring_schedule_name, batch_transform_input, problem_type, region_name):
17+
Initializes an AutomaticModelQualityDashboard instance.
18+
19+
_generate_widgets(self):
20+
Generates widgets based on the specified problem type and metrics.
21+
22+
to_dict(self):
23+
Converts the dashboard instance to a dictionary representation.
24+
25+
to_json(self):
26+
Converts the dashboard instance to a JSON string.
27+
"""
28+
529
MODEL_QUALITY_METRICS_ENDPOINT_NAMESPACE = (
630
"{aws/sagemaker/Endpoints/model-metrics,Endpoint,MonitoringSchedule}"
731
)
@@ -60,7 +84,21 @@ class AutomaticModelQualityDashboard:
6084
]
6185

6286
def __init__(self, endpoint_name, monitoring_schedule_name, batch_transform_input, problem_type, region_name):
63-
self.endpoint = endpoint_name
87+
"""
88+
Initializes an AutomaticModelQualityDashboard instance.
89+
90+
Args:
91+
endpoint_name (str): Name of the SageMaker endpoint.
92+
monitoring_schedule_name (str): Name of the monitoring schedule.
93+
batch_transform_input (str): Batch transform input (can be None).
94+
problem_type (str): Type of problem ('Regression', 'BinaryClassification', or 'MulticlassClassification').
95+
region_name (str): AWS region name.
96+
"""
97+
if type(endpoint_name) == EndpointInput:
98+
self.endpoint = endpoint_name.endpoint_name
99+
else:
100+
self.endpoint = endpoint_name
101+
64102
self.monitoring_schedule = monitoring_schedule_name
65103
self.batch_transform = batch_transform_input
66104
self.region = region_name
@@ -72,6 +110,12 @@ def __init__(self, endpoint_name, monitoring_schedule_name, batch_transform_inpu
72110

73111

74112
def _generate_widgets(self):
113+
"""
114+
Generates widgets based on the specified problem type and metrics.
115+
116+
Returns:
117+
list: List of DashboardWidget instances representing each metric graph.
118+
"""
75119
list_of_widgets = []
76120
metrics_to_graph = None
77121
if (self.problem_type == "Regression"):
@@ -82,11 +126,12 @@ def _generate_widgets(self):
82126
metrics_to_graph = AutomaticModelQualityDashboard.MULTICLASS_CLASSIFICATION_MODEL_QUALITY_METRICS
83127
else:
84128
raise ValueError("Parameter problem_type is invalid. Valid options are Regression, BinaryClassification, or MulticlassClassification.")
85-
129+
86130
for graphs_per_line in metrics_to_graph:
87131
for graph in graphs_per_line:
88132
graph_title = graph[0]
89133
graph_metrics = graph[1]
134+
metrics_string = " OR ".join(graph_metrics)
90135
if self.batch_transform is not None:
91136
graph_properties = DashboardWidgetProperties(
92137
view="timeSeries",
@@ -96,7 +141,7 @@ def _generate_widgets(self):
96141
{
97142
"expression": (
98143
f"SEARCH( '{AutomaticModelQualityDashboard.MODEL_QUALITY_METRICS_BATCH_NAMESPACE} "
99-
f"{" OR ".join(graph_metrics)}"
144+
f"{metrics_string} "
100145
f"MonitoringSchedule=\"{self.monitoring_schedule}\" ', "
101146
f"'Average')"
102147
)
@@ -115,7 +160,7 @@ def _generate_widgets(self):
115160
{
116161
"expression": (
117162
f"SEARCH( '{AutomaticModelQualityDashboard.MODEL_QUALITY_METRICS_ENDPOINT_NAMESPACE} "
118-
f"{" OR ".join(graph_metrics)}"
163+
f"{metrics_string} "
119164
f"Endpoint=\"{self.endpoint}\" "
120165
f"MonitoringSchedule=\"{self.monitoring_schedule}\" ', "
121166
f"'Average')"
@@ -128,16 +173,28 @@ def _generate_widgets(self):
128173
)
129174
list_of_widgets.append(
130175
DashboardWidget(
131-
height=8, width=24//len(graph_metrics), widget_type="metric", properties=graph_properties
176+
height=8, width=24//len(graphs_per_line), widget_type="metric", properties=graph_properties
132177
)
133178
)
134179

135180
return list_of_widgets
136181

137182
def to_dict(self):
183+
"""
184+
Converts the AutomaticModelQualityDashboard instance to a dictionary representation.
185+
186+
Returns:
187+
dict: Dictionary containing the dashboard widgets.
188+
"""
138189
return {
139190
"widgets": [widget.to_dict() for widget in self.dashboard["widgets"]],
140191
}
141192

142193
def to_json(self):
194+
"""
195+
Converts the AutomaticModelQualityDashboard instance to a JSON string.
196+
197+
Returns:
198+
str: JSON string representation of the dashboard widgets.
199+
"""
143200
return json.dumps(self.to_dict(), indent=4)

src/sagemaker/model_monitor/model_monitoring.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3290,10 +3290,11 @@ def create_monitoring_schedule(
32903290
cw_client = self.sagemaker_session.boto_session.client("cloudwatch")
32913291
cw_client.put_dashboard(
32923292
DashboardName=dashboard_name,
3293-
DashboardBody=AutomaticDataQualityDashboard(
3293+
DashboardBody=AutomaticModelQualityDashboard(
32943294
endpoint_name=endpoint_input,
32953295
monitoring_schedule_name=monitor_schedule_name,
32963296
batch_transform_input=batch_transform_input,
3297+
problem_type=problem_type,
32973298
region_name=self.sagemaker_session.boto_region_name,
32983299
).to_json(),
32993300
)

0 commit comments

Comments
 (0)