Skip to content

Commit 2150bad

Browse files
author
Sushanth Sathish Kumar
committed
fix: Fix circular dependecy with EndpointInput and dashboards
1 parent fa5b5f0 commit 2150bad

File tree

7 files changed

+154
-42
lines changed

7 files changed

+154
-42
lines changed

src/sagemaker/dashboard/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from __future__ import absolute_import
2+
3+
from sagemaker.dashboard.data_quality_dashboard import AutomaticDataQualityDashboard
4+
from sagemaker.dashboard.model_quality_dashboard import AutomaticModelQualityDashboard

src/sagemaker/dashboard/dashboard_variables.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
to publish the generated dashboards. To be used to aid dashboard creation in ClarifyModelMonitor
1717
and ModelMonitor.
1818
"""
19-
19+
from __future__ import absolute_import
2020
import json
2121

2222

src/sagemaker/dashboard/dashboard_widgets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
to publish the generated dashboards. To be used to aid dashboard creation in ClarifyModelMonitor
1717
and ModelMonitor.
1818
"""
19-
19+
from __future__ import absolute_import
2020
import json
2121

2222

src/sagemaker/dashboard/data_quality_dashboard.py

Lines changed: 100 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,53 @@
1313
"""This module the wrapper class for data quality dashboard. To be used to aid dashboard
1414
creation in ModelMonitor.
1515
"""
16+
from __future__ import absolute_import
1617

1718
import json
1819
from sagemaker.dashboard.dashboard_variables import DashboardVariable
1920
from sagemaker.dashboard.dashboard_widgets import DashboardWidget, DashboardWidgetProperties
20-
from sagemaker.model_monitor.model_monitoring import EndpointInput
2121

2222

2323
class AutomaticDataQualityDashboard:
24+
"""A wrapper class for creating a data quality dashboard to aid ModelMonitor dashboard creation.
25+
26+
This class generates dashboard variables and widgets based on the endpoint and monitoring
27+
schedule provided.
28+
29+
Attributes:
30+
DATA_QUALITY_METRICS_ENDPOINT_NAMESPACE (str): Namespace for endpoint data quality metrics.
31+
DATA_QUALITY_METRICS_BATCH_NAMESPACE (str): Namespace for batch transform data quality metrics.
32+
33+
Methods:
34+
__init__(self, endpoint_name, monitoring_schedule_name, batch_transform_input, region_name):
35+
Initializes the AutomaticDataQualityDashboard instance.
36+
37+
_generate_variables(self):
38+
Generates variables for the dashboard based on whether batch transform is used or not.
39+
40+
_generate_type_counts_widget(self):
41+
Generates a widget for displaying type counts based on endpoint or batch transform.
42+
43+
_generate_null_counts_widget(self):
44+
Generates a widget for displaying null and non-null counts based on endpoint or batch transform.
45+
46+
_generate_estimated_unique_values_widget(self):
47+
Generates a widget for displaying estimated unique values based on endpoint or batch transform.
48+
49+
_generate_completeness_widget(self):
50+
Generates a widget for displaying completeness based on endpoint or batch transform.
51+
52+
_generate_baseline_drift_widget(self):
53+
Generates a widget for displaying baseline drift based on endpoint or batch transform.
54+
55+
to_dict(self):
56+
Converts the dashboard configuration to a dictionary representation.
57+
58+
to_json(self):
59+
Converts the dashboard configuration to a JSON formatted string.
60+
61+
"""
62+
2463
DATA_QUALITY_METRICS_ENDPOINT_NAMESPACE = (
2564
"{aws/sagemaker/Endpoints/data-metrics,Endpoint,Feature,MonitoringSchedule}"
2665
)
@@ -29,11 +68,19 @@ class AutomaticDataQualityDashboard:
2968
)
3069

3170
def __init__(self, endpoint_name, monitoring_schedule_name, batch_transform_input, region_name):
32-
if type(endpoint_name) == EndpointInput:
33-
self.endpoint = endpoint_name.endpoint_name
34-
else:
35-
self.endpoint = endpoint_name
71+
"""Initializes an instance of AutomaticDataQualityDashboard.
72+
73+
Args:
74+
endpoint_name (str or EndpointInput): Name of the endpoint or EndpointInput object.
75+
monitoring_schedule_name (str): Name of the monitoring schedule.
76+
batch_transform_input (str): Name of the batch transform input.
77+
region_name (str): AWS region name.
78+
79+
If endpoint_name is of type EndpointInput, it extracts endpoint_name from it.
80+
81+
"""
3682

83+
self.endpoint = endpoint_name
3784
self.monitoring_schedule = monitoring_schedule_name
3885
self.batch_transform = batch_transform_input
3986
self.region = region_name
@@ -57,6 +104,12 @@ def __init__(self, endpoint_name, monitoring_schedule_name, batch_transform_inpu
57104
}
58105

59106
def _generate_variables(self):
107+
"""Generates dashboard variables based on the presence of batch transform.
108+
109+
Returns:
110+
list: List of DashboardVariable objects.
111+
112+
"""
60113
if self.batch_transform is not None:
61114
return [
62115
DashboardVariable(
@@ -83,6 +136,12 @@ def _generate_variables(self):
83136
]
84137

85138
def _generate_type_counts_widget(self):
139+
"""Generates a widget for displaying type counts based on endpoint or batch transform.
140+
141+
Returns:
142+
DashboardWidget: A DashboardWidget object configured for type counts.
143+
144+
"""
86145
if self.batch_transform is not None:
87146
type_counts_widget_properties = DashboardWidgetProperties(
88147
view="timeSeries",
@@ -139,6 +198,12 @@ def _generate_type_counts_widget(self):
139198
)
140199

141200
def _generate_null_counts_widget(self):
201+
"""Generates a widget for displaying null and non-null counts based on endpoint or batch transform.
202+
203+
Returns:
204+
DashboardWidget: A DashboardWidget object configured for null counts.
205+
206+
"""
142207
if self.batch_transform is not None:
143208
null_counts_widget_properties = DashboardWidgetProperties(
144209
view="timeSeries",
@@ -186,6 +251,12 @@ def _generate_null_counts_widget(self):
186251
)
187252

188253
def _generate_estimated_unique_values_widget(self):
254+
"""Generates a widget for displaying estimated unique values based on endpoint or batch transform.
255+
256+
Returns:
257+
DashboardWidget: A DashboardWidget object configured for estimated unique values.
258+
259+
"""
189260
if self.batch_transform is not None:
190261
estimated_unique_vals_widget_properties = DashboardWidgetProperties(
191262
view="timeSeries",
@@ -237,6 +308,12 @@ def _generate_estimated_unique_values_widget(self):
237308
)
238309

239310
def _generate_completeness_widget(self):
311+
"""Generates a widget for displaying completeness based on endpoint or batch transform.
312+
313+
Returns:
314+
DashboardWidget: A DashboardWidget object configured for completeness.
315+
316+
"""
240317
if self.batch_transform is not None:
241318
completeness_widget_properties = DashboardWidgetProperties(
242319
view="timeSeries",
@@ -285,6 +362,12 @@ def _generate_completeness_widget(self):
285362
)
286363

287364
def _generate_baseline_drift_widget(self):
365+
"""Generates a widget for displaying baseline drift based on endpoint or batch transform.
366+
367+
Returns:
368+
DashboardWidget: A DashboardWidget object configured for baseline drift.
369+
370+
"""
288371
if self.batch_transform is not None:
289372
baseline_drift_widget_properties = DashboardWidgetProperties(
290373
view="timeSeries",
@@ -332,10 +415,22 @@ def _generate_baseline_drift_widget(self):
332415
)
333416

334417
def to_dict(self):
418+
"""Converts the AutomaticDataQualityDashboard configuration to a dictionary representation.
419+
420+
Returns:
421+
dict: A dictionary containing variables and widgets configurations.
422+
423+
"""
335424
return {
336425
"variables": [var.to_dict() for var in self.dashboard["variables"]],
337426
"widgets": [widget.to_dict() for widget in self.dashboard["widgets"]],
338427
}
339428

340429
def to_json(self):
430+
"""Converts the AutomaticDataQualityDashboard configuration to a JSON formatted string.
431+
432+
Returns:
433+
str: A JSON formatted string representation of the dashboard configuration.
434+
435+
"""
341436
return json.dumps(self.to_dict(), indent=4)

src/sagemaker/dashboard/model_quality_dashboard.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1+
from __future__ import absolute_import
12
import json
23
from sagemaker.dashboard.dashboard_widgets import DashboardWidget, DashboardWidgetProperties
3-
from sagemaker.model_monitor import EndpointInput
44

55

66
class AutomaticModelQualityDashboard:
7-
"""
8-
Represents a dashboard for automatic model quality metrics in Amazon SageMaker.
7+
"""Represents a dashboard for automatic model quality metrics in Amazon SageMaker.
98
109
Attributes:
1110
MODEL_QUALITY_METRICS_ENDPOINT_NAMESPACE (str): Namespace for model metrics at endpoint level.
@@ -106,11 +105,7 @@ def __init__(
106105
problem_type (str): Type of problem ('Regression', 'BinaryClassification', or 'MulticlassClassification').
107106
region_name (str): AWS region name.
108107
"""
109-
if type(endpoint_name) == EndpointInput:
110-
self.endpoint = endpoint_name.endpoint_name
111-
else:
112-
self.endpoint = endpoint_name
113-
108+
self.endpoint = endpoint_name
114109
self.monitoring_schedule = monitoring_schedule_name
115110
self.batch_transform = batch_transform_input
116111
self.region = region_name

0 commit comments

Comments
 (0)