diff --git a/src/sagemaker/dashboard/__init__.py b/src/sagemaker/dashboard/__init__.py new file mode 100644 index 0000000000..63886fa268 --- /dev/null +++ b/src/sagemaker/dashboard/__init__.py @@ -0,0 +1,27 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Imports the classes in this module to simplify customer imports + +Example: + >>> from sagemaker.dashboard import AutomaticDataQualityDashboard + +""" +from __future__ import absolute_import + +from sagemaker.dashboard.data_quality_dashboard import AutomaticDataQualityDashboard # noqa: F401 +from sagemaker.dashboard.model_quality_dashboard import AutomaticModelQualityDashboard # noqa: F401 +from sagemaker.dashboard.dashboard_variables import DashboardVariable # noqa: F401 +from sagemaker.dashboard.dashboard_widgets import ( # noqa: F401 + DashboardWidget, + DashboardWidgetProperties, +) diff --git a/src/sagemaker/dashboard/dashboard_variables.py b/src/sagemaker/dashboard/dashboard_variables.py new file mode 100644 index 0000000000..c55fba2785 --- /dev/null +++ b/src/sagemaker/dashboard/dashboard_variables.py @@ -0,0 +1,87 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains code containing wrapper classes for dashboard variables in CloudWatch. + +These classes assist with creating dashboards in Python3 and then using boto3 CloudWatch client +to publish the generated dashboards. To be used to aid dashboard creation in ClarifyModelMonitor +and ModelMonitor. +""" +from __future__ import absolute_import +import json + + +class DashboardVariable: + """Represents a dashboard variable used for dynamic configuration in CloudWatch Dashboards. + + Attributes: + variable_type (str): Type of dashboard variable ('property' or 'pattern'). + variable_property (str): Property affected by the variable, such as metric dimension. + inputType (str): Type of input field ('input', 'select', or 'radio') for user interaction. + id (str): Identifier for the variable, up to 32 characters. + label (str): Label displayed for the input field (optional, defaults based on context). + search (str): Metric search expression to populate fields (required for 'select'). + populateFrom (str): Dimension name used to populate fields from search results. + """ + + def __init__( + self, variable_type, variable_property, inputType, variable_id, label, search, populateFrom + ): + """Initializes a DashboardVariable instance. + + Args: + variable_type (str): Type of dashboard variable ('property' or 'pattern'). 
+            variable_property (str): Property affected by the variable, such as metric dimension.
+            inputType (str): Type of input field ('input', 'select', or 'radio').
+            variable_id (str): Identifier for the variable, up to 32 characters.
+            label (str, optional): Label displayed for the input field (default is None).
+            search (str, optional): Metric search expression to populate input options.
+            populateFrom (str, optional): Dimension name used to populate field.
+        """
+        self.variable_type = variable_type
+        self.variable_property = variable_property
+        self.inputType = inputType
+        self.id = variable_id
+        self.label = label
+        self.search = search
+        self.populateFrom = populateFrom
+
+    def to_dict(self):
+        """Converts DashboardVariable instance to a dictionary representation.
+
+        Returns:
+            dict: Dictionary containing variable properties suitable for JSON serialization.
+        """
+        variable_properties_dict = {}
+        if self.variable_type is not None:
+            variable_properties_dict["type"] = self.variable_type
+        if self.variable_property is not None:
+            variable_properties_dict["property"] = self.variable_property
+        if self.inputType is not None:
+            variable_properties_dict["inputType"] = self.inputType
+        if self.id is not None:
+            variable_properties_dict["id"] = self.id
+        if self.label is not None:
+            variable_properties_dict["label"] = self.label
+        if self.search is not None:
+            variable_properties_dict["search"] = self.search
+        if self.populateFrom is not None:
+            variable_properties_dict["populateFrom"] = self.populateFrom
+        return variable_properties_dict
+
+    def to_json(self):
+        """Converts DashboardVariable instance to a JSON string.
+
+        Returns:
+            str: JSON string representation of the variable properties.
+        """
+        return json.dumps(self.to_dict(), indent=4)
diff --git a/src/sagemaker/dashboard/dashboard_widgets.py b/src/sagemaker/dashboard/dashboard_widgets.py
new file mode 100644
index 0000000000..e7516e5a02
--- /dev/null
+++ b/src/sagemaker/dashboard/dashboard_widgets.py
@@ -0,0 +1,144 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""This module contains wrapper classes for dashboard widgets in CloudWatch.
+
+These classes assist with creating dashboards in Python3 and then using boto3 CloudWatch client
+to publish the generated dashboards. To be used to aid dashboard creation in ClarifyModelMonitor
+and ModelMonitor.
+"""
+from __future__ import absolute_import
+import json
+
+
+class DashboardWidgetProperties:
+    """Represents properties of a dashboard widget used for metrics in CloudWatch.
+
+    Attributes:
+        view (str): Type of visualization ('timeSeries', 'bar', 'pie', 'table').
+        stacked (bool): Whether to display graph as stacked lines (applies to 'timeSeries' view).
+        metrics (list): Array of metrics configurations for the widget.
+        region (str): Region associated with the metrics.
+        period (int): Period in seconds for data points on the graph.
+        title (str): Title displayed for the graph or number (optional).
+        markdown (str): Markdown content to display within the widget (optional).
+    """
+
+    def __init__(
+        self,
+        view=None,
+        stacked=None,
+        metrics=None,
+        region=None,
+        period=None,
+        title=None,
+        markdown=None,
+    ):
+        """Initializes a DashboardWidgetProperties instance.
+
+        Args:
+            view (str, optional): Type of visualization ('timeSeries', 'bar', 'pie', 'table').
+            stacked (bool, optional): Whether to display the graph as stacked lines.
+            metrics (list, optional): Array of metrics configurations for the widget.
+            region (str, optional): Region associated with the metrics.
+            period (int, optional): Period in seconds for data points on the graph.
+            title (str, optional): Title displayed for the graph or number.
+            markdown (str, optional): Markdown content to display within the widget.
+        """
+        self.view = view
+        self.stacked = stacked
+        self.metrics = metrics
+        self.region = region
+        self.period = period
+        self.title = title
+        self.markdown = markdown
+
+    def to_dict(self):
+        """Converts DashboardWidgetProperties instance to a dictionary representation.
+
+        Returns:
+            dict: Dictionary containing widget properties suitable for JSON serialization.
+        """
+        widget_properties_dict = {}
+        if self.view is not None:
+            widget_properties_dict["view"] = self.view
+        if self.period is not None:
+            widget_properties_dict["period"] = self.period
+        if self.markdown is not None:
+            widget_properties_dict["markdown"] = self.markdown
+        if self.stacked is not None:
+            widget_properties_dict["stacked"] = self.stacked
+        if self.region is not None:
+            widget_properties_dict["region"] = self.region
+        if self.metrics is not None:
+            widget_properties_dict["metrics"] = self.metrics
+        if self.title is not None:
+            widget_properties_dict["title"] = self.title
+        return widget_properties_dict
+
+    def to_json(self):
+        """Converts DashboardWidgetProperties instance to a JSON string.
+
+        Returns:
+            str: JSON string representation of the widget properties.
+        """
+        return json.dumps(self.to_dict(), indent=4)
+
+
+class DashboardWidget:
+    """Represents a widget in a CloudWatch dashboard.
+
+    Attributes:
+        height (int): Height of the widget.
+        width (int): Width of the widget.
+        type (str): Type of the widget.
+        properties (DashboardWidgetProperties): Properties specific to the widget type.
+    """
+
+    def __init__(self, height, width, widget_type, properties=None):
+        """Initializes a DashboardWidget instance.
+
+        Args:
+            height (int): Height of the widget.
+            width (int): Width of the widget.
+            widget_type (str): Type of the widget.
+            properties (DashboardWidgetProperties, optional): Properties of the widget type.
+        """
+        self.height = height
+        self.width = width
+        self.type = widget_type
+        self.properties = (
+            properties
+            if properties
+            else DashboardWidgetProperties(None, False, [], None, None, None)
+        )
+
+    def to_dict(self):
+        """Converts DashboardWidget instance to a dictionary representation.
+
+        Returns:
+            dict: Dictionary containing widget attributes suitable for JSON serialization.
+        """
+        return {
+            "height": self.height,
+            "width": self.width,
+            "type": self.type,
+            "properties": self.properties.to_dict(),
+        }
+
+    def to_json(self):
+        """Converts DashboardWidget instance to a JSON string.
+
+        Returns:
+            str: JSON string representation of the widget attributes.
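+
+        Example (illustrative sketch; the metric expression, region, and title are placeholders):
+            >>> props = DashboardWidgetProperties(
+            ...     view="timeSeries",
+            ...     stacked=False,
+            ...     metrics=[[{"expression": "SEARCH( '<placeholder namespace and filter>', 'Average')"}]],
+            ...     region="us-west-2",
+            ...     title="Missing Data Counts",
+            ... )
+            >>> widget = DashboardWidget(height=8, width=12, widget_type="metric", properties=props)
+            >>> print(widget.to_json())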
+ """ + return json.dumps(self.to_dict(), indent=4) diff --git a/src/sagemaker/dashboard/data_quality_dashboard.py b/src/sagemaker/dashboard/data_quality_dashboard.py new file mode 100644 index 0000000000..a212f1c891 --- /dev/null +++ b/src/sagemaker/dashboard/data_quality_dashboard.py @@ -0,0 +1,440 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module the wrapper class for data quality dashboard. + +To be used to aid dashboard creation in ModelMonitor. +""" +from __future__ import absolute_import + +import json +from sagemaker.dashboard.dashboard_variables import DashboardVariable +from sagemaker.dashboard.dashboard_widgets import DashboardWidget, DashboardWidgetProperties + + +class AutomaticDataQualityDashboard: + """A wrapper class for creating a data quality dashboard to aid ModelMonitor dashboard creation. + + This class generates dashboard variables and widgets based on the endpoint and monitoring + schedule provided. + + Attributes: + DATA_QUALITY_METRICS_ENDPOINT_NAMESPACE (str): Namespace for endpoint. + DATA_QUALITY_METRICS_BATCH_NAMESPACE (str): Namespace for batch transform. + + Methods: + __init__(self, endpoint_name, monitoring_schedule_name, batch_transform_input, region_name): + Initializes the AutomaticDataQualityDashboard instance. + + _generate_variables(self): + Generates variables for the dashboard based on whether batch transform is used or not. + + _generate_type_counts_widget(self): + Generates a widget for displaying type counts. + + _generate_null_counts_widget(self): + Generates a widget for displaying null and non-null counts. + + _generate_estimated_unique_values_widget(self): + Generates a widget for displaying estimated unique values. + + _generate_completeness_widget(self): + Generates a widget for displaying completeness. + + _generate_baseline_drift_widget(self): + Generates a widget for displaying baseline drift. + + to_dict(self): + Converts the dashboard configuration to a dictionary representation. + + to_json(self): + Converts the dashboard configuration to a JSON formatted string. + + """ + + DATA_QUALITY_METRICS_ENDPOINT_NAMESPACE = ( + "{aws/sagemaker/Endpoints/data-metrics,Endpoint,Feature,MonitoringSchedule}" + ) + DATA_QUALITY_METRICS_BATCH_NAMESPACE = ( + "{aws/sagemaker/ModelMonitoring/data-metrics,Feature,MonitoringSchedule}" + ) + + def __init__(self, endpoint_name, monitoring_schedule_name, batch_transform_input, region_name): + """Initializes an instance of AutomaticDataQualityDashboard. + + Args: + endpoint_name (str or EndpointInput): Name of the endpoint or EndpointInput object. + monitoring_schedule_name (str): Name of the monitoring schedule. + batch_transform_input (str): Name of the batch transform input. + region_name (str): AWS region name. + + If endpoint_name is of type EndpointInput, it extracts endpoint_name from it. 
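+
+        Example (illustrative sketch; the endpoint, schedule, and dashboard names are
+        placeholders):
+            >>> dashboard = AutomaticDataQualityDashboard(
+            ...     endpoint_name="my-endpoint",
+            ...     monitoring_schedule_name="my-data-quality-schedule",
+            ...     batch_transform_input=None,
+            ...     region_name="us-west-2",
+            ... )
+            >>> import boto3
+            >>> boto3.client("cloudwatch", region_name="us-west-2").put_dashboard(
+            ...     DashboardName="my-data-quality-dashboard",
+            ...     DashboardBody=dashboard.to_json(),
+            ... )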
+ + """ + + self.endpoint = endpoint_name + self.monitoring_schedule = monitoring_schedule_name + self.batch_transform = batch_transform_input + self.region = region_name + + variables = self._generate_variables() + type_counts_widget = self._generate_type_counts_widget() + null_counts_widget = self._generate_null_counts_widget() + estimated_unique_values_widget = self._generate_estimated_unique_values_widget() + completeness_widget = self._generate_completeness_widget() + baseline_drift_widget = self._generate_baseline_drift_widget() + + self.dashboard = { + "variables": variables, + "widgets": [ + type_counts_widget, + null_counts_widget, + estimated_unique_values_widget, + completeness_widget, + baseline_drift_widget, + ], + } + + def _generate_variables(self): + """Generates dashboard variables based on the presence of batch transform. + + Returns: + list: List of DashboardVariable objects. + + """ + if self.batch_transform is not None: + return [ + DashboardVariable( + variable_type="property", + variable_property="Feature", + inputType="select", + variable_id="Feature", + label="Feature", + search=self.DATA_QUALITY_METRICS_BATCH_NAMESPACE + + f' MonitoringSchedule="{self.monitoring_schedule}" ', + populateFrom="Feature", + ) + ] + + return [ + DashboardVariable( + variable_type="property", + variable_property="Feature", + inputType="select", + variable_id="Feature", + label="Feature", + search=self.DATA_QUALITY_METRICS_ENDPOINT_NAMESPACE + + f' Endpoint="{self.endpoint}"' + + f' MonitoringSchedule="{self.monitoring_schedule}" ', + populateFrom="Feature", + ) + ] + + def _generate_type_counts_widget(self): + """Generates a widget for displaying type counts based on endpoint or batch transform. + + Returns: + DashboardWidget: A DashboardWidget object configured for type counts. + + """ + if self.batch_transform is not None: + type_counts_widget_properties = DashboardWidgetProperties( + view="timeSeries", + stacked=False, + metrics=[ + [ + { + "expression": ( + f"SEARCH( '{self.DATA_QUALITY_METRICS_BATCH_NAMESPACE} " + f"%^feature_fractional_counts_.*% OR " + f"%^feature_integral_counts_.*% OR " + f"%^feature_string_counts_.*% OR " + f"%^feature_boolean_counts_.*% OR " + f"%^feature_unknown_counts_.*% " + f'Feature="_" ' + f'MonitoringSchedule="{self.monitoring_schedule}" \', ' + f"'Average')" + ) + } + ] + ], + region=self.region, + title="Type Counts", + ) + + else: + type_counts_widget_properties = DashboardWidgetProperties( + view="timeSeries", + stacked=False, + metrics=[ + [ + { + "expression": ( + f"SEARCH( '{self.DATA_QUALITY_METRICS_ENDPOINT_NAMESPACE} " + f"%^feature_fractional_counts_.*% OR " + f"%^feature_integral_counts_.*% OR " + f"%^feature_string_counts_.*% OR " + f"%^feature_boolean_counts_.*% OR " + f"%^feature_unknown_counts_.*% " + f'Endpoint="{self.endpoint}" ' + f'Feature="_" ' + f'MonitoringSchedule="{self.monitoring_schedule}" \', ' + f"'Average')" + ) + } + ] + ], + region=self.region, + title="Type Counts", + ) + + return DashboardWidget( + height=8, width=12, widget_type="metric", properties=type_counts_widget_properties + ) + + def _generate_null_counts_widget(self): + """Generates a widget for displaying null and non-null counts. + + Returns: + DashboardWidget: A DashboardWidget object configured for null counts. 
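+
+        Example of the metric expression built for an endpoint input (placeholder endpoint and
+        schedule names, wrapped here for readability):
+            SEARCH( '{aws/sagemaker/Endpoints/data-metrics,Endpoint,Feature,MonitoringSchedule}
+            %^feature_null_.*% OR %^feature_non_null_.*% Endpoint="my-endpoint" Feature="_"
+            MonitoringSchedule="my-schedule" ', 'Average')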
+ + """ + if self.batch_transform is not None: + null_counts_widget_properties = DashboardWidgetProperties( + view="timeSeries", + stacked=False, + metrics=[ + [ + { + "expression": ( + f"SEARCH( '{self.DATA_QUALITY_METRICS_BATCH_NAMESPACE} " + f"%^feature_null_.*% OR %^feature_non_null_.*% " + f'Feature="_" ' + f'MonitoringSchedule="{self.monitoring_schedule}" \', ' + f"'Average')" + ) + } + ] + ], + region=self.region, + title="Missing Data Counts", + ) + + else: + null_counts_widget_properties = DashboardWidgetProperties( + view="timeSeries", + stacked=False, + metrics=[ + [ + { + "expression": ( + f"SEARCH( '{self.DATA_QUALITY_METRICS_ENDPOINT_NAMESPACE} " + f"%^feature_null_.*% OR %^feature_non_null_.*% " + f'Endpoint="{self.endpoint}" ' + f'Feature="_" ' + f'MonitoringSchedule="{self.monitoring_schedule}" \', ' + f"'Average')" + ) + } + ] + ], + region=self.region, + title="Missing Data Counts", + ) + return DashboardWidget( + height=8, width=12, widget_type="metric", properties=null_counts_widget_properties + ) + + def _generate_estimated_unique_values_widget(self): + """Generates a widget for displaying estimated unique values. + + Returns: + DashboardWidget: A DashboardWidget object configured for estimated unique values. + + """ + if self.batch_transform is not None: + estimated_unique_vals_widget_properties = DashboardWidgetProperties( + view="timeSeries", + stacked=False, + metrics=[ + [ + { + "expression": ( + f"SEARCH( '{self.DATA_QUALITY_METRICS_BATCH_NAMESPACE} " + f"%^feature_estimated_unique_values_.*% " + f'Feature="_" ' + f'MonitoringSchedule="{self.monitoring_schedule}" \', ' + f"'Average')" + ) + } + ] + ], + region=self.region, + title="Estimated Unique Values", + ) + + else: + estimated_unique_vals_widget_properties = DashboardWidgetProperties( + view="timeSeries", + stacked=False, + metrics=[ + [ + { + "expression": ( + f"SEARCH( '{self.DATA_QUALITY_METRICS_ENDPOINT_NAMESPACE} " + f"%^feature_estimated_unique_values_.*% " + f'Endpoint="{self.endpoint}" ' + f'Feature="_" ' + f'MonitoringSchedule="{self.monitoring_schedule}" \', ' + f"'Average')" + ) + } + ] + ], + region=self.region, + title="Estimated Unique Values", + ) + + return DashboardWidget( + height=8, + width=12, + widget_type="metric", + properties=estimated_unique_vals_widget_properties, + ) + + def _generate_completeness_widget(self): + """Generates a widget for displaying completeness based on endpoint or batch transform. + + Returns: + DashboardWidget: A DashboardWidget object configured for completeness. 
+ + """ + if self.batch_transform is not None: + completeness_widget_properties = DashboardWidgetProperties( + view="timeSeries", + stacked=False, + metrics=[ + [ + { + "expression": ( + f"SEARCH( '{self.DATA_QUALITY_METRICS_BATCH_NAMESPACE} " + f"%^feature_completeness_.*% " + f'Feature="_" ' + f'MonitoringSchedule="{self.monitoring_schedule}" \', ' + f"'Average')" + ) + } + ] + ], + region=self.region, + title="Completeness", + ) + + else: + completeness_widget_properties = DashboardWidgetProperties( + view="timeSeries", + stacked=False, + metrics=[ + [ + { + "expression": ( + f"SEARCH( '{self.DATA_QUALITY_METRICS_ENDPOINT_NAMESPACE} " + f"%^feature_completeness_.*% " + f'Endpoint="{self.endpoint}" ' + f'Feature="_" ' + f'MonitoringSchedule="{self.monitoring_schedule}" \', ' + f"'Average')" + ) + } + ] + ], + region=self.region, + title="Completeness", + ) + + return DashboardWidget( + height=8, width=12, widget_type="metric", properties=completeness_widget_properties + ) + + def _generate_baseline_drift_widget(self): + """Generates a widget for displaying baseline drift based on endpoint or batch transform. + + Returns: + DashboardWidget: A DashboardWidget object configured for baseline drift. + + """ + if self.batch_transform is not None: + baseline_drift_widget_properties = DashboardWidgetProperties( + view="timeSeries", + stacked=False, + metrics=[ + [ + { + "expression": ( + f"SEARCH( '{self.DATA_QUALITY_METRICS_BATCH_NAMESPACE} " + f"%^feature_baseline_drift_.*% " + f'Feature="_" ' + f'MonitoringSchedule="{self.monitoring_schedule}" \', ' + f"'Average')" + ) + } + ] + ], + region=self.region, + title="Baseline Drift", + ) + + else: + baseline_drift_widget_properties = DashboardWidgetProperties( + view="timeSeries", + stacked=False, + metrics=[ + [ + { + "expression": ( + f"SEARCH( '{self.DATA_QUALITY_METRICS_ENDPOINT_NAMESPACE} " + f"%^feature_baseline_drift_.*% " + f'Endpoint="{self.endpoint}" ' + f'Feature="_" ' + f'MonitoringSchedule="{self.monitoring_schedule}" \', ' + f"'Average')" + ) + } + ] + ], + region=self.region, + title="Baseline Drift", + ) + return DashboardWidget( + height=8, width=12, widget_type="metric", properties=baseline_drift_widget_properties + ) + + def to_dict(self): + """Converts the AutomaticDataQualityDashboard configuration to a dictionary representation. + + Returns: + dict: A dictionary containing variables and widgets configurations. + + """ + return { + "variables": [var.to_dict() for var in self.dashboard["variables"]], + "widgets": [widget.to_dict() for widget in self.dashboard["widgets"]], + } + + def to_json(self): + """Converts the AutomaticDataQualityDashboard configuration to a JSON formatted string. + + Returns: + str: A JSON formatted string representation of the dashboard configuration. + + """ + return json.dumps(self.to_dict(), indent=4) diff --git a/src/sagemaker/dashboard/model_quality_dashboard.py b/src/sagemaker/dashboard/model_quality_dashboard.py new file mode 100644 index 0000000000..6b4d791cef --- /dev/null +++ b/src/sagemaker/dashboard/model_quality_dashboard.py @@ -0,0 +1,220 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module the wrapper class for model quality dashboard. + +To be used to aid dashboard creation in ModelMonitor. +""" + +from __future__ import absolute_import +import json +from sagemaker.dashboard.dashboard_widgets import DashboardWidget, DashboardWidgetProperties + + +class AutomaticModelQualityDashboard: + """Represents a dashboard for automatic model quality metrics in Amazon SageMaker. + + Methods: + __init__(self, endpoint_name, monitoring_schedule_name, + batch_transform_input, problem_type, region_name): + Initializes an AutomaticModelQualityDashboard instance. + + _generate_widgets(self): + Generates widgets based on the specified problem type and metrics. + + to_dict(self): + Converts the dashboard instance to a dictionary representation. + + to_json(self): + Converts the dashboard instance to a JSON string. + """ + + MODEL_QUALITY_METRICS_ENDPOINT_NAMESPACE = ( + "{aws/sagemaker/Endpoints/model-metrics,Endpoint,MonitoringSchedule}" + ) + + MODEL_QUALITY_METRICS_BATCH_NAMESPACE = ( + "{aws/sagemaker/ModelMonitoring/model-metrics,MonitoringSchedule}" + ) + + REGRESSION_MODEL_QUALITY_METRICS = [ + # The outer list represents the graphs per line in cloudwatch + [ + # each tuple here contains the title and the metrics that are being graphed + ("Mean Squared Error", ["mse"]), + ("Root Mean Squared Error", ["rmse"]), + ], + [ + ("R-squared", ["r2"]), + ("Mean Absolute Error", ["mae"]), + ], + ] + + BINARY_CLASSIFICATION_MODEL_QUALITY_METRICS = [ + [ + ("Accuracy", ["accuracy", "accuracy_best_constant_classifier"]), + ("Precision", ["precision", "precision_best_constant_classifier"]), + ("Recall", ["recall", "recall_best_constant_classifier"]), + ], + [ + ("F0.5", ["f0_5", "f0_5_best_constant_classifier"]), + ("F1", ["f1", "f1_best_constant_classifier"]), + ("F2", ["f2", "f2_best_constant_classifier"]), + ], + [ + ("True Positive Rate", ["true_positive_rate"]), + ("True Negative Rate", ["true_negative_rate"]), + ("False Positive Rate", ["false_positive_rate"]), + ("False Negative Rate", ["false_negative_rate"]), + ], + [ + ("Area Under Precision-Recall Curve", ["au_prc"]), + ("Area Under ROC curve", ["auc"]), + ], + ] + + MULTICLASS_CLASSIFICATION_MODEL_QUALITY_METRICS = [ + [ + ("Accuracy", ["accuracy", "accuracy_best_constant_classifier"]), + ( + "Weighted Precision", + ["weighted_precision", "weighted_precision_best_constant_classifier"], + ), + ("Weighted Recall", ["weighted_recall", "weighted_recall_best_constant_classifier"]), + ], + [ + ("Weighted F0.5", ["weighted_f0_5", "weighted_f0_5_best_constant_classifier"]), + ("Weighted F1", ["weighted_f1", "weighted_f1_best_constant_classifier"]), + ("Weighted F2", ["weighted_f2", "weighted_f2_best_constant_classifier"]), + ], + ] + + def __init__( + self, + endpoint_name, + monitoring_schedule_name, + batch_transform_input, + problem_type, + region_name, + ): + """Initializes an AutomaticModelQualityDashboard instance. + + Args: + endpoint_name (str): Name of the SageMaker endpoint. + monitoring_schedule_name (str): Name of the monitoring schedule. + batch_transform_input (str): Batch transform input (can be None). + problem_type (str): Type of problem + ('Regression', 'BinaryClassification', 'MulticlassClassification'). + region_name (str): AWS region name. 
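+
+        Example (illustrative sketch; the endpoint and schedule names are placeholders):
+            >>> dashboard = AutomaticModelQualityDashboard(
+            ...     endpoint_name="my-endpoint",
+            ...     monitoring_schedule_name="my-model-quality-schedule",
+            ...     batch_transform_input=None,
+            ...     problem_type="BinaryClassification",
+            ...     region_name="us-west-2",
+            ... )
+            >>> body = dashboard.to_json()  # JSON body accepted by CloudWatch put_dashboard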
+ """ + self.endpoint = endpoint_name + self.monitoring_schedule = monitoring_schedule_name + self.batch_transform = batch_transform_input + self.region = region_name + self.problem_type = problem_type + + self.dashboard = { + "widgets": self._generate_widgets(), + } + + def _generate_widgets(self): + """Generates widgets based on the specified problem type and metrics. + + Returns: + list: List of DashboardWidget instances representing each metric graph. + """ + list_of_widgets = [] + metrics_to_graph = None + if self.problem_type == "Regression": + metrics_to_graph = self.REGRESSION_MODEL_QUALITY_METRICS + elif self.problem_type == "BinaryClassification": + metrics_to_graph = self.BINARY_CLASSIFICATION_MODEL_QUALITY_METRICS + elif self.problem_type == "MulticlassClassification": + metrics_to_graph = self.MULTICLASS_CLASSIFICATION_MODEL_QUALITY_METRICS + else: + raise ValueError( + "Parameter problem_type is invalid. Valid options are " + "Regression, BinaryClassification, or MulticlassClassification." + ) + + for graphs_per_line in metrics_to_graph: + for graph in graphs_per_line: + graph_title = graph[0] + graph_metrics = ["%^" + str(metric) + "$%" for metric in graph[1]] + metrics_string = " OR ".join(graph_metrics) + if self.batch_transform is not None: + graph_properties = DashboardWidgetProperties( + view="timeSeries", + stacked=False, + metrics=[ + [ + { + "expression": ( + f"SEARCH( '{self.MODEL_QUALITY_METRICS_BATCH_NAMESPACE} " + f"{metrics_string} " + f'MonitoringSchedule="{self.monitoring_schedule}" \', ' + "'Average')" + ) + } + ] + ], + region=self.region, + title=graph_title, + ) + else: + graph_properties = DashboardWidgetProperties( + view="timeSeries", + stacked=False, + metrics=[ + [ + { + "expression": ( + f"SEARCH( '{self.MODEL_QUALITY_METRICS_ENDPOINT_NAMESPACE} " + f"{metrics_string} " + f'Endpoint="{self.endpoint}" ' + f'MonitoringSchedule="{self.monitoring_schedule}" \', ' + f"'Average')" + ) + } + ] + ], + region=self.region, + title=graph_title, + ) + list_of_widgets.append( + DashboardWidget( + height=8, + width=24 // len(graphs_per_line), + widget_type="metric", + properties=graph_properties, + ) + ) + + return list_of_widgets + + def to_dict(self): + """Converts the AutomaticModelQualityDashboard instance to a dictionary representation. + + Returns: + dict: Dictionary containing the dashboard widgets. + """ + return { + "widgets": [widget.to_dict() for widget in self.dashboard["widgets"]], + } + + def to_json(self): + """Converts the AutomaticModelQualityDashboard instance to a JSON string. + + Returns: + str: JSON string representation of the dashboard widgets. 
+ """ + return json.dumps(self.to_dict(), indent=4) diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index 436377fea5..7343052596 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -24,6 +24,7 @@ import logging import uuid from typing import Union, Optional, Dict, List +import re import attr from six import string_types @@ -44,6 +45,7 @@ MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH, MONITORING_JOB_ROLE_ARN_PATH, ) +from sagemaker.dashboard.model_quality_dashboard import AutomaticModelQualityDashboard from sagemaker.exceptions import UnexpectedStatusException from sagemaker.model_monitor.monitoring_files import Constraints, ConstraintViolations, Statistics from sagemaker.model_monitor.monitoring_alert import ( @@ -67,6 +69,8 @@ from sagemaker.lineage._utils import get_resource_name_from_arn from sagemaker.model_monitor.cron_expression_generator import CronExpressionGenerator +from sagemaker.dashboard.data_quality_dashboard import AutomaticDataQualityDashboard + DEFAULT_REPOSITORY_NAME = "sagemaker-model-monitor-analyzer" STATISTICS_JSON_DEFAULT_FILE_NAME = "statistics.json" @@ -1535,6 +1539,78 @@ def _check_monitoring_schedule_cron_validity( _LOGGER.error(message) raise ValueError(message) + def _check_dashboard_validity_without_checking_in_use( + self, + monitor_schedule_name, + enable_cloudwatch_metrics=True, + dashboard_name=None, + ): + """Checks if the parameters are valid, without checking if dashboard name is taken + + Args: + monitor_schedule_name (str): Monitoring schedule name. + enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of + the baselining or monitoring jobs. + dashboard_name (str): The name to use when publishing dashboard + """ + if not enable_cloudwatch_metrics: + message = ( + "Could not create automatic dashboard. " + "Please set enable_cloudwatch_metrics to True." + ) + _LOGGER.error(message) + raise ValueError(message) + + if dashboard_name is None: + dashboard_name = monitor_schedule_name + dashboard_name_validation = bool(re.match(r"^[0-9A-Za-z\-_]{1,255}$", dashboard_name)) + if not dashboard_name_validation: + message = ( + f"Dashboard name {dashboard_name} is not a valid dashboard name. " + "Dashboard name can be at most 255 characters long " + "and valid characters in dashboard names include '0-9A-Za-z-_'." + ) + _LOGGER.error(message) + raise ValueError(message) + + def _check_automatic_dashboard_validity( + self, + cw_client, + monitor_schedule_name, + enable_cloudwatch_metrics=True, + dashboard_name=None, + ): + """Checks if the parameters provided to generate an automatic dashboard are valid + + Args: + monitor_schedule_name (str): Monitoring schedule name. + enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of + the baselining or monitoring jobs. + dashboard_name (str): The name to use when publishing dashboard + """ + + self._check_dashboard_validity_without_checking_in_use( + monitor_schedule_name=monitor_schedule_name, + enable_cloudwatch_metrics=enable_cloudwatch_metrics, + dashboard_name=dashboard_name, + ) + + # flag to check if dashboard with name dashboard_name exists already + dashboard_exists = True + try: + cw_client.get_dashboard(DashboardName=dashboard_name) + except ClientError as _: # noqa: F841 + dashboard_exists = False + + if dashboard_exists: + message = ( + f"Dashboard name {dashboard_name} is already in use. 
" + "Please provide a different dashboard name, or delete the already " + "existing dashboard." + ) + _LOGGER.error(message) + raise ValueError(message) + def _create_monitoring_schedule_from_job_definition( self, monitor_schedule_name, @@ -1945,6 +2021,8 @@ def create_monitoring_schedule( monitor_schedule_name=None, schedule_cron_expression=None, enable_cloudwatch_metrics=True, + enable_automatic_dashboard=False, + dashboard_name=None, batch_transform_input=None, data_analysis_start_time=None, data_analysis_end_time=None, @@ -1981,6 +2059,10 @@ def create_monitoring_schedule( expressions. Default: Daily. enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of the baselining or monitoring jobs. + enable_automatic_dashboard (bool): Whether to publish an automatic dashboard as part of + the baselining or monitoring jobs. + dashboard_name (str): Name to use for the published dashboard. When not provided, + defaults to monitoring schedule name. batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run the monitoring schedule on the batch transform (default: None) data_analysis_start_time (str): Start time for the data analysis window @@ -1988,6 +2070,7 @@ def create_monitoring_schedule( data_analysis_end_time (str): End time for the data analysis window for the one time monitoring schedule (NOW), e.g. "-PT1H" (default: None) """ + if self.job_definition_name is not None or self.monitoring_schedule_name is not None: message = ( "It seems that this object was already used to create an Amazon Model " @@ -2012,6 +2095,15 @@ def create_monitoring_schedule( data_analysis_end_time=data_analysis_end_time, ) + if enable_automatic_dashboard: + cw_client = self.sagemaker_session.boto_session.client("cloudwatch") + self._check_automatic_dashboard_validity( + cw_client=cw_client, + monitor_schedule_name=monitor_schedule_name, + enable_cloudwatch_metrics=enable_cloudwatch_metrics, + dashboard_name=dashboard_name, + ) + # create job definition monitor_schedule_name = self._generate_monitoring_schedule_name( schedule_name=monitor_schedule_name @@ -2069,6 +2161,25 @@ def create_monitoring_schedule( logger.exception(message) raise + if enable_automatic_dashboard: + if dashboard_name is None: + dashboard_name = monitor_schedule_name + if isinstance(endpoint_input, EndpointInput): + endpoint_name = endpoint_input.endpoint_name + else: + endpoint_name = endpoint_input + + cw_client = self.sagemaker_session.boto_session.client("cloudwatch") + cw_client.put_dashboard( + DashboardName=dashboard_name, + DashboardBody=AutomaticDataQualityDashboard( + endpoint_name=endpoint_name, + monitoring_schedule_name=monitor_schedule_name, + batch_transform_input=batch_transform_input, + region_name=self.sagemaker_session.boto_region_name, + ).to_json(), + ) + def update_monitoring_schedule( self, endpoint_input=None, @@ -2087,6 +2198,8 @@ def update_monitoring_schedule( env=None, network_config=None, enable_cloudwatch_metrics=None, + enable_automatic_dashboard=None, + dashboard_name=None, role=None, batch_transform_input=None, data_analysis_start_time=None, @@ -2131,6 +2244,10 @@ def update_monitoring_schedule( inter-container traffic, security group IDs, and subnets. enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of the baselining or monitoring jobs. + enable_automatic_dashboard (bool): Whether to publish an automatic dashboard as part of + the baselining or monitoring jobs. + dashboard_name (str): Name to use for the published dashboard. 
When not provided, + defaults to monitoring schedule name. role (str): An AWS IAM role name or ARN. The Amazon SageMaker jobs use this role. batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run the monitoring schedule on the batch transform (default: None) @@ -2150,6 +2267,16 @@ def update_monitoring_schedule( logger.error(message) raise ValueError(message) + # error checking for dashboard + if enable_automatic_dashboard: + cw_client = self.sagemaker_session.boto_session.client("cloudwatch") + self._check_automatic_dashboard_validity( + cw_client=cw_client, + monitor_schedule_name=self.monitoring_schedule_name, + enable_cloudwatch_metrics=enable_cloudwatch_metrics, + dashboard_name=dashboard_name, + ) + # check if this schedule is in v2 format and update as per v2 format if it is if self.job_definition_name is not None: self._update_data_quality_monitoring_schedule( @@ -2280,6 +2407,25 @@ def update_monitoring_schedule( self._wait_for_schedule_changes_to_apply() + if enable_automatic_dashboard: + if dashboard_name is None: + dashboard_name = self.monitoring_schedule_name + if isinstance(endpoint_input, EndpointInput): + endpoint_name = endpoint_input.endpoint_name + else: + endpoint_name = endpoint_input + + cw_client = self.sagemaker_session.boto_session.client("cloudwatch") + cw_client.put_dashboard( + DashboardName=dashboard_name, + DashboardBody=AutomaticDataQualityDashboard( + endpoint_name=endpoint_name, + monitoring_schedule_name=self.monitoring_schedule_name, + batch_transform_input=batch_transform_input, + region_name=self.sagemaker_session.boto_region_name, + ).to_json(), + ) + def _update_data_quality_monitoring_schedule( self, endpoint_input=None, @@ -3061,6 +3207,8 @@ def create_monitoring_schedule( monitor_schedule_name=None, schedule_cron_expression=None, enable_cloudwatch_metrics=True, + enable_automatic_dashboard=False, + dashboard_name=None, batch_transform_input=None, data_analysis_start_time=None, data_analysis_end_time=None, @@ -3092,6 +3240,10 @@ def create_monitoring_schedule( expressions. Default: Daily. enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of the baselining or monitoring jobs. + enable_automatic_dashboard (bool): Whether to publish an automatic dashboard as part of + the baselining or monitoring jobs. + dashboard_name (str): Name to use for the published dashboard. When not provided, + defaults to monitoring schedule name. 
batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run the monitoring schedule on the batch transform data_analysis_start_time (str): Start time for the data analysis window @@ -3131,6 +3283,15 @@ def create_monitoring_schedule( data_analysis_end_time=data_analysis_end_time, ) + if enable_automatic_dashboard: + cw_client = self.sagemaker_session.boto_session.client("cloudwatch") + self._check_automatic_dashboard_validity( + cw_client=cw_client, + monitor_schedule_name=monitor_schedule_name, + enable_cloudwatch_metrics=enable_cloudwatch_metrics, + dashboard_name=dashboard_name, + ) + # create job definition monitor_schedule_name = self._generate_monitoring_schedule_name( schedule_name=monitor_schedule_name @@ -3189,6 +3350,23 @@ def create_monitoring_schedule( logger.exception(message) raise + if enable_automatic_dashboard: + if isinstance(endpoint_input, EndpointInput): + endpoint_name = endpoint_input.endpoint_name + else: + endpoint_name = endpoint_input + cw_client = self.sagemaker_session.boto_session.client("cloudwatch") + cw_client.put_dashboard( + DashboardName=dashboard_name, + DashboardBody=AutomaticModelQualityDashboard( + endpoint_name=endpoint_name, + monitoring_schedule_name=monitor_schedule_name, + batch_transform_input=batch_transform_input, + problem_type=problem_type, + region_name=self.sagemaker_session.boto_region_name, + ).to_json(), + ) + def update_monitoring_schedule( self, endpoint_input=None, @@ -3200,6 +3378,8 @@ def update_monitoring_schedule( constraints=None, schedule_cron_expression=None, enable_cloudwatch_metrics=None, + enable_automatic_dashboard=None, + dashboard_name=None, role=None, instance_count=None, instance_type=None, @@ -3238,6 +3418,10 @@ def update_monitoring_schedule( expressions. Default: Daily. enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of the baselining or monitoring jobs. + enable_automatic_dashboard (bool): Whether to publish an automatic dashboard as part of + the baselining or monitoring jobs. + dashboard_name (str): Name to use for the published dashboard. When not provided, + defaults to monitoring schedule name. role (str): An AWS IAM role. The Amazon SageMaker jobs use this role. instance_count (int): The number of instances to run the jobs with. 
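# Illustrative usage sketch for the new scheduling flags documented above; the role, endpoint,
# bucket, and dashboard names below are placeholder values, not part of the change itself.
from sagemaker.model_monitor import ModelQualityMonitor

monitor = ModelQualityMonitor(
    role="my-sagemaker-execution-role",  # placeholder IAM role
    instance_count=1,
    instance_type="ml.m5.xlarge",
)
monitor.create_monitoring_schedule(
    endpoint_input="my-endpoint",
    ground_truth_input="s3://my-bucket/ground-truth/",
    problem_type="BinaryClassification",
    output_s3_uri="s3://my-bucket/monitoring-output/",
    enable_cloudwatch_metrics=True,  # must be True for the automatic dashboard
    enable_automatic_dashboard=True,
    dashboard_name="my-model-quality-dashboard",  # defaults to the schedule name when omitted
)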
@@ -3287,6 +3471,15 @@ def update_monitoring_schedule( logger.error(message) raise ValueError(message) + if enable_automatic_dashboard: + cw_client = self.sagemaker_session.boto_session.client("cloudwatch") + self._check_automatic_dashboard_validity( + cw_client=cw_client, + monitor_schedule_name=self.monitoring_schedule_name, + enable_cloudwatch_metrics=enable_cloudwatch_metrics, + dashboard_name=dashboard_name, + ) + # Need to update schedule with a new job definition job_desc = self.sagemaker_session.sagemaker_client.describe_model_quality_job_definition( JobDefinitionName=self.job_definition_name @@ -3356,6 +3549,25 @@ def update_monitoring_schedule( logger.exception(message) raise + if enable_automatic_dashboard: + if dashboard_name is None: + dashboard_name = self.monitoring_schedule_name + if isinstance(endpoint_input, EndpointInput): + endpoint_name = endpoint_input.endpoint_name + else: + endpoint_name = endpoint_input + cw_client = self.sagemaker_session.boto_session.client("cloudwatch") + cw_client.put_dashboard( + DashboardName=dashboard_name, + DashboardBody=AutomaticModelQualityDashboard( + endpoint_name=endpoint_name, + monitoring_schedule_name=self.monitoring_schedule_name, + batch_transform_input=batch_transform_input, + problem_type=problem_type, + region_name=self.sagemaker_session.boto_region_name, + ).to_json(), + ) + def delete_monitoring_schedule(self): """Deletes the monitoring schedule and its job definition.""" super(ModelQualityMonitor, self).delete_monitoring_schedule() diff --git a/tests/integ/test_model_monitor.py b/tests/integ/test_model_monitor.py index 17ea70699b..de9cf4d6b2 100644 --- a/tests/integ/test_model_monitor.py +++ b/tests/integ/test_model_monitor.py @@ -63,6 +63,8 @@ encrypt_inter_container_traffic=True, ) ENABLE_CLOUDWATCH_METRICS = True +ENABLE_AUTOMATIC_DASHBOARD = True +DASHBOARD_NAME = "DASHBOARD NAME" DEFAULT_BASELINING_MAX_RUNTIME_IN_SECONDS = 86400 DEFAULT_EXECUTION_MAX_RUNTIME_IN_SECONDS = 3600 @@ -1082,6 +1084,131 @@ def test_default_monitor_create_and_update_schedule_config_with_customizations( assert len(predictor.list_monitors()) > 0 +@pytest.mark.skipif( + tests.integ.test_region() in tests.integ.NO_MODEL_MONITORING_REGIONS, + reason="ModelMonitoring is not yet supported in this region.", +) +def test_default_monitor_create_and_update_schedule_config_with_dashboards( + sagemaker_session, + predictor, + volume_kms_key, + output_kms_key, + updated_volume_kms_key, + updated_output_kms_key, +): + my_default_monitor = DefaultModelMonitor( + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + volume_size_in_gb=VOLUME_SIZE_IN_GB, + volume_kms_key=volume_kms_key, + output_kms_key=output_kms_key, + max_runtime_in_seconds=MAX_RUNTIME_IN_SECONDS, + sagemaker_session=sagemaker_session, + env=ENVIRONMENT, + tags=TAGS, + network_config=NETWORK_CONFIG, + ) + + output_s3_uri = os.path.join( + "s3://", + sagemaker_session.default_bucket(), + "integ-test-monitoring-output-bucket", + str(uuid.uuid4()), + ) + + statistics = Statistics.from_file_path( + statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"), + sagemaker_session=sagemaker_session, + ) + + constraints = Constraints.from_file_path( + constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"), + sagemaker_session=sagemaker_session, + ) + + my_default_monitor.create_monitoring_schedule( + endpoint_input=predictor.endpoint_name, + output_s3_uri=output_s3_uri, + statistics=statistics, + constraints=constraints, + 
schedule_cron_expression=CronExpressionGenerator.daily(), + enable_cloudwatch_metrics=ENABLE_CLOUDWATCH_METRICS, + enable_automatic_dashboard=ENABLE_AUTOMATIC_DASHBOARD, + dashboard_name=DASHBOARD_NAME, + ) + + schedule_description = my_default_monitor.describe_schedule() + _verify_default_monitoring_schedule( + sagemaker_session=sagemaker_session, + schedule_description=schedule_description, + statistics=statistics, + constraints=constraints, + output_kms_key=output_kms_key, + volume_kms_key=volume_kms_key, + network_config=NETWORK_CONFIG, + ) + + statistics = Statistics.from_file_path( + statistics_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/statistics.json"), + sagemaker_session=sagemaker_session, + ) + + constraints = Constraints.from_file_path( + constraints_file_path=os.path.join(tests.integ.DATA_DIR, "monitor/constraints.json"), + sagemaker_session=sagemaker_session, + ) + + _wait_for_schedule_changes_to_apply(monitor=my_default_monitor) + + my_default_monitor.update_monitoring_schedule( + output_s3_uri=output_s3_uri, + statistics=statistics, + constraints=constraints, + schedule_cron_expression=CronExpressionGenerator.hourly(), + instance_count=UPDATED_INSTANCE_COUNT, + instance_type=UPDATED_INSTANCE_TYPE, + volume_size_in_gb=UPDATED_VOLUME_SIZE_IN_GB, + volume_kms_key=updated_volume_kms_key, + output_kms_key=updated_output_kms_key, + max_runtime_in_seconds=UPDATED_MAX_RUNTIME_IN_SECONDS, + env=UPDATED_ENVIRONMENT, + network_config=UPDATED_NETWORK_CONFIG, + enable_cloudwatch_metrics=ENABLE_CLOUDWATCH_METRICS, + enable_automatic_dashboard=ENABLE_AUTOMATIC_DASHBOARD, + role=UPDATED_ROLE, + ) + + _wait_for_schedule_changes_to_apply(my_default_monitor) + + schedule_description = my_default_monitor.describe_schedule() + _verify_default_monitoring_schedule( + sagemaker_session=sagemaker_session, + schedule_description=schedule_description, + statistics=statistics, + constraints=constraints, + output_kms_key=updated_output_kms_key, + volume_kms_key=updated_volume_kms_key, + cron_expression=CronExpressionGenerator.hourly(), + instant_count=UPDATED_INSTANCE_COUNT, + instant_type=UPDATED_INSTANCE_TYPE, + volume_size_in_gb=UPDATED_VOLUME_SIZE_IN_GB, + network_config=UPDATED_NETWORK_CONFIG, + max_runtime_in_seconds=UPDATED_MAX_RUNTIME_IN_SECONDS, + publish_cloudwatch_metrics="Disabled", + env_key=UPDATED_ENV_KEY_1, + env_value=UPDATED_ENV_VALUE_1, + ) + + _wait_for_schedule_changes_to_apply(monitor=my_default_monitor) + + my_default_monitor.stop_monitoring_schedule() + + _wait_for_schedule_changes_to_apply(monitor=my_default_monitor) + + assert len(predictor.list_monitors()) > 0 + + @pytest.mark.skipif( tests.integ.test_region() in tests.integ.NO_MODEL_MONITORING_REGIONS, reason="ModelMonitoring is not yet supported in this region.", diff --git a/tests/integ/test_model_quality_monitor.py b/tests/integ/test_model_quality_monitor.py index 1fafa96cfb..a588ae0998 100644 --- a/tests/integ/test_model_quality_monitor.py +++ b/tests/integ/test_model_quality_monitor.py @@ -176,6 +176,39 @@ def scheduled_model_quality_monitor( return model_quality_monitor +@pytest.fixture +def scheduled_model_quality_monitor_with_dashboard( + sagemaker_session, model_quality_monitor, endpoint_name, ground_truth_input +): + monitor_schedule_name = utils.unique_name_from_base("model-quality-monitor") + s3_uri_monitoring_output = os.path.join( + "s3://", + sagemaker_session.default_bucket(), + endpoint_name, + monitor_schedule_name, + "monitor_output", + ) + # To include attributes + endpoint_input = 
EndpointInput( + endpoint_name=endpoint_name, + destination=ENDPOINT_INPUT_LOCAL_PATH, + start_time_offset=START_TIME_OFFSET, + end_time_offset=END_TIME_OFFSET, + inference_attribute=INFERENCE_ATTRIBUTE, + ) + model_quality_monitor.create_monitoring_schedule( + monitor_schedule_name=monitor_schedule_name, + endpoint_input=endpoint_input, + ground_truth_input=ground_truth_input, + problem_type=PROBLEM_TYPE, + output_s3_uri=s3_uri_monitoring_output, + schedule_cron_expression=CRON, + enable_cloudwatch_metrics=True, + enable_automatic_dashboards=True, + ) + return model_quality_monitor + + @pytest.mark.skipif( tests.integ.test_region() in tests.integ.NO_MODEL_MONITORING_REGIONS, reason="ModelMonitoring is not yet supported in this region.", @@ -230,6 +263,65 @@ def test_model_quality_monitor( monitor.delete_monitoring_schedule() +@pytest.mark.skipif( + tests.integ.test_region() in tests.integ.NO_MODEL_MONITORING_REGIONS, + reason="ModelMonitoring is not yet supported in this region.", +) +def test_model_quality_monitor_with_dashboard( + sagemaker_session, + scheduled_model_quality_monitor_with_dashboard, + endpoint_name, + ground_truth_input, +): + monitor = scheduled_model_quality_monitor_with_dashboard + monitor._wait_for_schedule_changes_to_apply() + + # stop it as soon as possible to avoid any execution + monitor.stop_monitoring_schedule() + _verify_monitoring_schedule( + monitor=monitor, + schedule_status="Stopped", + ) + _verify_model_quality_job_description( + sagemaker_session=sagemaker_session, + monitor=monitor, + endpoint_name=endpoint_name, + ground_truth_input=ground_truth_input, + ) + + # attach to schedule + monitoring_schedule_name = monitor.monitoring_schedule_name + job_definition_name = monitor.job_definition_name + monitor = ModelQualityMonitor.attach( + monitor_schedule_name=monitor.monitoring_schedule_name, + sagemaker_session=sagemaker_session, + ) + assert monitor.monitoring_schedule_name == monitoring_schedule_name + assert monitor.job_definition_name == job_definition_name + + # update schedule + monitor.update_monitoring_schedule( + max_runtime_in_seconds=UPDATED_MAX_RUNTIME_IN_SECONDS, + schedule_cron_expression=UPDATED_CRON, + enable_automatic_dashboards=True, + ) + assert monitor.monitoring_schedule_name == monitoring_schedule_name + assert monitor.job_definition_name != job_definition_name + _verify_monitoring_schedule( + monitor=monitor, schedule_status="Scheduled", schedule_cron_expression=UPDATED_CRON + ) + _verify_model_quality_job_description( + sagemaker_session=sagemaker_session, + monitor=monitor, + endpoint_name=endpoint_name, + ground_truth_input=ground_truth_input, + max_runtime_in_seconds=UPDATED_MAX_RUNTIME_IN_SECONDS, + ) + + # delete schedule + monitor.delete_monitoring_schedule() + + @pytest.mark.slow_test @pytest.mark.skipif( tests.integ.test_region() in tests.integ.NO_MODEL_MONITORING_REGIONS, diff --git a/tests/unit/sagemaker/monitor/test_model_monitoring.py b/tests/unit/sagemaker/monitor/test_model_monitoring.py index d31b9f8527..744238dffe 100644 --- a/tests/unit/sagemaker/monitor/test_model_monitoring.py +++ b/tests/unit/sagemaker/monitor/test_model_monitoring.py @@ -19,6 +19,7 @@ import sagemaker import pytest from mock import Mock, patch, MagicMock +from botocore.exceptions import ClientError from sagemaker.model_monitor import ( ModelMonitor, @@ -464,11 +465,31 @@ "AlertStatus": "Ok", } +INVALID_DASHBOARD_NAME = "!@#$%^&*(){}?/[]" +VALID_DASHBOARD_NAME = "dashboard_1" +EXISTING_DASHBOARD_NAME = "dashboard_0" +NEW_DASHBOARD_NAME = 
"dashboard_2" + + +def mock_get_dashboard(DashboardName): + if DashboardName == EXISTING_DASHBOARD_NAME: + return {"DashboardName": DashboardName, "DashboardBody": "dummy_dashboard_body"} + raise ClientError( + error_response={ + "Error": {}, + }, + operation_name=f"Dashboard '{DashboardName}' not found.", + ) + # TODO-reinvent-2019: Continue to flesh these out. @pytest.fixture() def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) + boto_mock_client = Mock(name="cloudwatch_client") + boto_mock_client.put_dashboard = Mock() + boto_mock.client = boto_mock_client + session_mock = Mock( name="sagemaker_session", boto_session=boto_mock, @@ -477,6 +498,9 @@ def sagemaker_session(): local_mode=False, default_bucket_prefix=None, ) + + session_mock.boto_session.client("cloudwatch").get_dashboard = mock_get_dashboard + session_mock.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) session_mock.upload_data = Mock( name="upload_data", return_value="mocked_s3_uri_from_upload_data" @@ -850,6 +874,30 @@ def test_data_quality_monitor(data_quality_monitor, sagemaker_session): ) +def test_data_quality_monitor_with_dashboard(data_quality_monitor, sagemaker_session): + # create schedule + _test_data_quality_monitor_create_schedule( + data_quality_monitor=data_quality_monitor, + sagemaker_session=sagemaker_session, + constraints=CONSTRAINTS, + statistics=STATISTICS, + enable_cloudwatch_metrics=True, + enable_automatic_dashboard=True, + ) + + # update schedule + _test_data_quality_monitor_update_schedule( + data_quality_monitor=data_quality_monitor, + sagemaker_session=sagemaker_session, + ) + + # delete schedule + _test_data_quality_monitor_delete_schedule( + data_quality_monitor=data_quality_monitor, + sagemaker_session=sagemaker_session, + ) + + def test_data_quality_batch_transform_monitor(data_quality_monitor, sagemaker_session): # create schedule _test_data_quality_batch_transform_monitor_create_schedule( @@ -917,6 +965,43 @@ def test_data_quality_monitor_invalid_create(data_quality_monitor, sagemaker_ses ) +def test_data_quality_monitor_invalid_dashboard_create(data_quality_monitor, sagemaker_session): + # invalid: cannot create a monitoring schedule with an invalid dashboard name + with pytest.raises(ValueError): + _test_data_quality_monitor_create_schedule( + data_quality_monitor=data_quality_monitor, + sagemaker_session=sagemaker_session, + constraints=CONSTRAINTS, + enable_cloudwatch_metrics=True, + enable_automatic_dashboard=True, + dashboard_name=INVALID_DASHBOARD_NAME, + ) + + # invalid: cannot create a monitoring schedule when we set + # enable_automatic_dashboard to True but do not publish metrics + # to CW. 
+ with pytest.raises(ValueError): + _test_data_quality_monitor_create_schedule( + data_quality_monitor=data_quality_monitor, + sagemaker_session=sagemaker_session, + constraints=CONSTRAINTS, + enable_cloudwatch_metrics=False, + enable_automatic_dashboard=True, + dashboard_name=VALID_DASHBOARD_NAME, + ) + + # invalid: cannot create a monitoring schedule with existing dashboard + with pytest.raises(ValueError): + _test_data_quality_monitor_create_schedule( + data_quality_monitor=data_quality_monitor, + sagemaker_session=sagemaker_session, + constraints=CONSTRAINTS, + enable_cloudwatch_metrics=True, + enable_automatic_dashboard=True, + dashboard_name=EXISTING_DASHBOARD_NAME, + ) + + def test_data_quality_monitor_creation_failure(data_quality_monitor, sagemaker_session): sagemaker_session.sagemaker_client.create_monitoring_schedule = Mock( side_effect=Exception("400") @@ -991,6 +1076,9 @@ def _test_data_quality_monitor_create_schedule( endpoint_input=EndpointInput( endpoint_name=ENDPOINT_NAME, destination=ENDPOINT_INPUT_LOCAL_PATH ), + enable_cloudwatch_metrics=True, + enable_automatic_dashboard=False, + dashboard_name=None, ): # for endpoint input data_quality_monitor.create_monitoring_schedule( @@ -1002,6 +1090,9 @@ def _test_data_quality_monitor_create_schedule( statistics=statistics, monitor_schedule_name=SCHEDULE_NAME, schedule_cron_expression=CRON_HOURLY, + enable_cloudwatch_metrics=enable_cloudwatch_metrics, + enable_automatic_dashboard=enable_automatic_dashboard, + dashboard_name=dashboard_name, ) # validation @@ -1030,6 +1121,9 @@ def _test_data_quality_batch_transform_monitor_create_schedule( destination=SCHEDULE_DESTINATION, dataset_format=MonitoringDatasetFormat.csv(header=False), ), + enable_cloudwatch_metrics=True, + enable_automatic_dashboard=False, + dashboard_name=None, ): # for batch transform input data_quality_monitor.create_monitoring_schedule( @@ -1041,6 +1135,9 @@ def _test_data_quality_batch_transform_monitor_create_schedule( statistics=statistics, monitor_schedule_name=SCHEDULE_NAME, schedule_cron_expression=CRON_HOURLY, + enable_cloudwatch_metrics=enable_cloudwatch_metrics, + enable_automatic_dashboard=enable_automatic_dashboard, + dashboard_name=dashboard_name, ) # validation @@ -1494,6 +1591,28 @@ def test_model_quality_monitor(model_quality_monitor, sagemaker_session): ) +def test_model_quality_monitor_with_dashboard(model_quality_monitor, sagemaker_session): + # create schedule + _test_model_quality_monitor_create_schedule( + model_quality_monitor=model_quality_monitor, + sagemaker_session=sagemaker_session, + constraints=CONSTRAINTS, + enable_automatic_dashboard=True, + dashboard_name=NEW_DASHBOARD_NAME, + ) + + # update schedule + _test_model_quality_monitor_update_schedule( + model_quality_monitor=model_quality_monitor, + sagemaker_session=sagemaker_session, + ) + + _test_model_quality_monitor_delete_schedule( + model_quality_monitor=model_quality_monitor, + sagemaker_session=sagemaker_session, + ) + + def test_model_quality_batch_transform_monitor(model_quality_monitor, sagemaker_session): # create schedule _test_model_quality_monitor_batch_transform_create_schedule( @@ -1558,6 +1677,43 @@ def test_model_quality_monitor_invalid_create(model_quality_monitor, sagemaker_s ) +def test_model_quality_monitor_invalid_dashboard_create(model_quality_monitor, sagemaker_session): + # invalid: cannot create a monitoring schedule with an invalid dashboard name + with pytest.raises(ValueError): + _test_model_quality_monitor_create_schedule( + 
model_quality_monitor=model_quality_monitor, + sagemaker_session=sagemaker_session, + constraints=CONSTRAINTS, + enable_cloudwatch_metrics=True, + enable_automatic_dashboard=True, + dashboard_name=INVALID_DASHBOARD_NAME, + ) + + # invalid: cannot create a monitoring schedule when we set + # enable_automatic_dashboard to True but do not publish metrics + # to CW. + with pytest.raises(ValueError): + _test_model_quality_monitor_create_schedule( + model_quality_monitor=model_quality_monitor, + sagemaker_session=sagemaker_session, + constraints=CONSTRAINTS, + enable_cloudwatch_metrics=False, + enable_automatic_dashboard=True, + dashboard_name=VALID_DASHBOARD_NAME, + ) + + # invalid: cannot create a monitoring schedule with existing dashboard + with pytest.raises(ValueError): + _test_model_quality_monitor_create_schedule( + model_quality_monitor=model_quality_monitor, + sagemaker_session=sagemaker_session, + constraints=CONSTRAINTS, + enable_cloudwatch_metrics=True, + enable_automatic_dashboard=True, + dashboard_name=EXISTING_DASHBOARD_NAME, + ) + + def test_model_quality_monitor_creation_failure(model_quality_monitor, sagemaker_session): sagemaker_session.sagemaker_client.create_monitoring_schedule = Mock( side_effect=Exception("400") @@ -1639,6 +1795,9 @@ def _test_model_quality_monitor_create_schedule( probability_attribute=PROBABILITY_ATTRIBUTE, probability_threshold_attribute=PROBABILITY_THRESHOLD_ATTRIBUTE, ), + enable_cloudwatch_metrics=True, + enable_automatic_dashboard=False, + dashboard_name=None, ): model_quality_monitor.create_monitoring_schedule( endpoint_input=endpoint_input, @@ -1650,6 +1809,9 @@ def _test_model_quality_monitor_create_schedule( constraints=constraints, monitor_schedule_name=SCHEDULE_NAME, schedule_cron_expression=CRON_HOURLY, + enable_cloudwatch_metrics=enable_cloudwatch_metrics, + enable_automatic_dashboard=enable_automatic_dashboard, + dashboard_name=dashboard_name, ) # validation diff --git a/tests/unit/test_dashboard_methods.py b/tests/unit/test_dashboard_methods.py new file mode 100644 index 0000000000..1ec0837f43 --- /dev/null +++ b/tests/unit/test_dashboard_methods.py @@ -0,0 +1,418 @@ +from __future__ import absolute_import + +from sagemaker.dashboard.data_quality_dashboard import AutomaticDataQualityDashboard +from sagemaker.dashboard.dashboard_variables import DashboardVariable +from sagemaker.dashboard.dashboard_widgets import DashboardWidget, DashboardWidgetProperties + + +def test_variable_to_dict(): + var = DashboardVariable( + variable_type="property", + variable_property="Feature", + inputType="select", + variable_id="Feature", + label="Feature", + search="{aws/sagemaker/Endpoints/data-metrics,Endpoint,Feature,MonitoringSchedule}", + populateFrom="Feature", + ) + expected_dict = { + "type": "property", + "property": "Feature", + "inputType": "select", + "id": "Feature", + "label": "Feature", + "search": "{aws/sagemaker/Endpoints/data-metrics,Endpoint,Feature,MonitoringSchedule}", + "populateFrom": "Feature", + } + assert var.to_dict() == expected_dict + + +def test_widget_properties_to_dict(): + widget_properties = DashboardWidgetProperties( + view="timeSeries", + stacked=False, + metrics=[ + [ + { + "expression": ( + "SEARCH(" + " 'aws/sagemaker/Endpoints/data-metrics,Endpoint,Feature,MonitoringSchedule " + "%^(feature_null_|feature_non_null_).*% ', " + "'Average')" + ) + } + ] + ], + region="us-east-1", + title="Missing Data Counts", + ) + expected_dict = { + "view": "timeSeries", + "stacked": False, + "metrics": [ + [ + { + "expression": ( + 
"SEARCH( " + "'aws/sagemaker/Endpoints/data-metrics,Endpoint,Feature,MonitoringSchedule " + "%^(feature_null_|feature_non_null_).*% ', " + "'Average')" + ) + } + ] + ], + "region": "us-east-1", + "title": "Missing Data Counts", + } + assert widget_properties.to_dict() == expected_dict + + +def test_widget_to_dict(): + widget_properties = DashboardWidgetProperties( + view="timeSeries", + stacked=False, + metrics=[ + [ + { + "expression": ( + "SEARCH( 'aws/sagemaker/Endpoints/data-metrics,Endpoint,Feature,MonitoringSchedule " + "%^(feature_null_|feature_non_null_).*% ', 'Average')" + ) + } + ] + ], + region="us-east-1", + title="Missing Data Counts", + ) + widget = DashboardWidget(height=8, width=12, widget_type="metric", properties=widget_properties) + expected_dict = { + "height": 8, + "width": 12, + "type": "metric", + "properties": { + "view": "timeSeries", + "stacked": False, + "metrics": [ + [ + { + "expression": ( + "SEARCH( 'aws/sagemaker/Endpoints/data-metrics,Endpoint,Feature,MonitoringSchedule " + "%^(feature_null_|feature_non_null_).*% ', " + "'Average')" + ) + } + ] + ], + "region": "us-east-1", + "title": "Missing Data Counts", + }, + } + assert widget.to_dict() == expected_dict + + +def test_automatic_data_quality_dashboard_endpoint(): + mock_generate_variables = [ + DashboardVariable( + variable_type="property", + variable_property="Feature", + inputType="select", + variable_id="Feature", + label="Feature", + search="{aws/sagemaker/Endpoints/data-metrics,Endpoint,Feature,MonitoringSchedule}" + + ' Endpoint="endpoint"' + + ' MonitoringSchedule="monitoring_schedule" ', + populateFrom="Feature", + ) + ] + mock_generate_type_counts_widget = DashboardWidget( + height=8, + width=12, + widget_type="metric", + properties=DashboardWidgetProperties( + view="timeSeries", + stacked=False, + region="us-west-2", + metrics=[ + [ + { + "expression": ( + "SEARCH( '{aws/sagemaker/Endpoints/data-metrics,Endpoint,Feature,MonitoringSchedule} " + "%^feature_fractional_counts_.*% OR " + "%^feature_integral_counts_.*% OR " + "%^feature_string_counts_.*% OR " + "%^feature_boolean_counts_.*% OR " + "%^feature_unknown_counts_.*% " + 'Endpoint="endpoint" Feature="_" MonitoringSchedule="monitoring_schedule" \', ' + "'Average')" + ) + } + ] + ], + title="Type Counts", + ), + ) + mock_generate_null_counts_widget = DashboardWidget( + height=8, + width=12, + widget_type="metric", + properties=DashboardWidgetProperties( + view="timeSeries", + stacked=False, + region="us-west-2", + metrics=[ + [ + { + "expression": ( + "SEARCH( '{aws/sagemaker/Endpoints/data-metrics,Endpoint,Feature,MonitoringSchedule} " + "%^feature_null_.*% OR " + "%^feature_non_null_.*% " + 'Endpoint="endpoint" Feature="_" MonitoringSchedule="monitoring_schedule" \', ' + "'Average')" + ) + } + ] + ], + title="Missing Data Counts", + ), + ) + mock_generate_estimated_unique_values_widget = DashboardWidget( + height=8, + width=12, + widget_type="metric", + properties=DashboardWidgetProperties( + view="timeSeries", + stacked=False, + region="us-west-2", + metrics=[ + [ + { + "expression": ( + "SEARCH( '{aws/sagemaker/Endpoints/data-metrics,Endpoint,Feature,MonitoringSchedule} " + "%^feature_estimated_unique_values_.*% " + 'Endpoint="endpoint" Feature="_" MonitoringSchedule="monitoring_schedule" \', ' + "'Average')" + ) + } + ] + ], + title="Estimated Unique Values", + ), + ) + mock_generate_completeness_widget = DashboardWidget( + height=8, + width=12, + widget_type="metric", + properties=DashboardWidgetProperties( + view="timeSeries", + 
stacked=False, + region="us-west-2", + metrics=[ + [ + { + "expression": ( + "SEARCH( '{aws/sagemaker/Endpoints/data-metrics,Endpoint,Feature,MonitoringSchedule} " + "%^feature_completeness_.*% " + 'Endpoint="endpoint" Feature="_" MonitoringSchedule="monitoring_schedule" \', ' + "'Average')" + ) + } + ] + ], + title="Completeness", + ), + ) + + mock_generate_baseline_drift_widget = DashboardWidget( + height=8, + width=12, + widget_type="metric", + properties=DashboardWidgetProperties( + view="timeSeries", + stacked=False, + region="us-west-2", + metrics=[ + [ + { + "expression": ( + "SEARCH( '{aws/sagemaker/Endpoints/data-metrics,Endpoint,Feature,MonitoringSchedule} " + "%^feature_baseline_drift_.*% " + 'Endpoint="endpoint" Feature="_" MonitoringSchedule="monitoring_schedule" \', ' + "'Average')" + ) + } + ] + ], + title="Baseline Drift", + ), + ) + + dashboard = AutomaticDataQualityDashboard("endpoint", "monitoring_schedule", None, "us-west-2") + + expected_dashboard = { + "variables": [var.to_dict() for var in mock_generate_variables], + "widgets": [ + widget.to_dict() + for widget in [ + mock_generate_type_counts_widget, + mock_generate_null_counts_widget, + mock_generate_estimated_unique_values_widget, + mock_generate_completeness_widget, + mock_generate_baseline_drift_widget, + ] + ], + } + assert dashboard.to_dict() == expected_dashboard + + +def test_automatic_data_quality_dashboard_batch_transform(): + mock_generate_variables = [ + DashboardVariable( + variable_type="property", + variable_property="Feature", + inputType="select", + variable_id="Feature", + label="Feature", + search="{aws/sagemaker/ModelMonitoring/data-metrics,Feature,MonitoringSchedule}" + + ' MonitoringSchedule="monitoring_schedule" ', + populateFrom="Feature", + ) + ] + mock_generate_type_counts_widget = DashboardWidget( + height=8, + width=12, + widget_type="metric", + properties=DashboardWidgetProperties( + view="timeSeries", + stacked=False, + metrics=[ + [ + { + "expression": ( + "SEARCH( '{aws/sagemaker/ModelMonitoring/data-metrics,Feature,MonitoringSchedule} " + "%^feature_fractional_counts_.*% OR " + "%^feature_integral_counts_.*% OR " + "%^feature_string_counts_.*% OR " + "%^feature_boolean_counts_.*% OR " + "%^feature_unknown_counts_.*% " + 'Feature="_" MonitoringSchedule="monitoring_schedule" \', ' + "'Average')" + ) + } + ] + ], + region="us-west-2", + title="Type Counts", + ), + ) + mock_generate_null_counts_widget = DashboardWidget( + height=8, + width=12, + widget_type="metric", + properties=DashboardWidgetProperties( + view="timeSeries", + stacked=False, + metrics=[ + [ + { + "expression": ( + "SEARCH( '{aws/sagemaker/ModelMonitoring/data-metrics,Feature,MonitoringSchedule} " + "%^feature_null_.*% OR " + "%^feature_non_null_.*% " + 'Feature="_" MonitoringSchedule="monitoring_schedule" \', ' + "'Average')" + ) + } + ] + ], + region="us-west-2", + title="Missing Data Counts", + ), + ) + mock_generate_estimated_unique_values_widget = DashboardWidget( + height=8, + width=12, + widget_type="metric", + properties=DashboardWidgetProperties( + view="timeSeries", + stacked=False, + metrics=[ + [ + { + "expression": ( + "SEARCH( '{aws/sagemaker/ModelMonitoring/data-metrics,Feature,MonitoringSchedule} " + "%^feature_estimated_unique_values_.*% " + 'Feature="_" MonitoringSchedule="monitoring_schedule" \', ' + "'Average')" + ) + } + ] + ], + region="us-west-2", + title="Estimated Unique Values", + ), + ) + mock_generate_completeness_widget = DashboardWidget( + height=8, + width=12, + widget_type="metric", + 
properties=DashboardWidgetProperties( + view="timeSeries", + stacked=False, + metrics=[ + [ + { + "expression": ( + "SEARCH( '{aws/sagemaker/ModelMonitoring/data-metrics,Feature,MonitoringSchedule} " + "%^feature_completeness_.*% " + 'Feature="_" MonitoringSchedule="monitoring_schedule" \', ' + "'Average')" + ) + } + ] + ], + region="us-west-2", + title="Completeness", + ), + ) + mock_generate_baseline_drift_widget = DashboardWidget( + height=8, + width=12, + widget_type="metric", + properties=DashboardWidgetProperties( + view="timeSeries", + stacked=False, + metrics=[ + [ + { + "expression": ( + "SEARCH( '{aws/sagemaker/ModelMonitoring/data-metrics,Feature,MonitoringSchedule} " + "%^feature_baseline_drift_.*% " + 'Feature="_" MonitoringSchedule="monitoring_schedule" \', ' + "'Average')" + ) + } + ] + ], + region="us-west-2", + title="Baseline Drift", + ), + ) + + # Pass any non None value for batch transform input to check if the dashboard correctly uses the other namespace + dashboard = AutomaticDataQualityDashboard(None, "monitoring_schedule", True, "us-west-2") + + expected_dashboard = { + "variables": [var.to_dict() for var in mock_generate_variables], + "widgets": [ + widget.to_dict() + for widget in [ + mock_generate_type_counts_widget, + mock_generate_null_counts_widget, + mock_generate_estimated_unique_values_widget, + mock_generate_completeness_widget, + mock_generate_baseline_drift_widget, + ] + ], + } + + assert dashboard.to_dict() == expected_dashboard
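Usage note: the unit tests above only exercise to_dict() on the generated dashboard. The sketch below (not part of the diff) shows how the resulting definition could be published with the boto3 CloudWatch put_dashboard API; the endpoint name, schedule name, dashboard name, and region are hypothetical placeholders, and the positional constructor arguments follow the pattern used in the tests above.

import json

import boto3

from sagemaker.dashboard import AutomaticDataQualityDashboard

# Build the dashboard definition for a hypothetical endpoint and monitoring schedule.
# Passing None for the batch-transform input selects the endpoint metrics namespace,
# mirroring test_automatic_data_quality_dashboard_endpoint above.
dashboard = AutomaticDataQualityDashboard(
    "my-endpoint", "my-monitoring-schedule", None, "us-west-2"
)

# put_dashboard expects the dashboard body as a JSON string.
cloudwatch = boto3.client("cloudwatch", region_name="us-west-2")
cloudwatch.put_dashboard(
    DashboardName="my-data-quality-dashboard",
    DashboardBody=json.dumps(dashboard.to_dict()),
)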