|
8 | 8 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
9 | 9 | # ANY KIND, either express or implied. See the License for the specific |
10 | 10 | # language governing permissions and limitations under the License. |
11 | | -""" |
12 | | -This module provides visualization capabilities for SageMaker hyperparameter tuning jobs. |
| 11 | +"""This module provides visualization capabilities for SageMaker hyperparameter tuning jobs. |
13 | 12 |
|
14 | 13 | It contains utilities to create interactive visualizations of hyperparameter tuning results |
15 | 14 | using Altair charts. The module enables users to analyze and understand the performance |
@@ -83,8 +82,7 @@ def visualize_tuning_job( |
83 | 82 | trials_only: bool = False, |
84 | 83 | advanced: bool = False, |
85 | 84 | ) -> Union[alt.Chart, Tuple[alt.Chart, pd.DataFrame, pd.DataFrame]]: |
86 | | - """ |
87 | | - Visualize SageMaker hyperparameter tuning jobs. |
| 85 | + """Visualize SageMaker hyperparameter tuning jobs. |
88 | 86 |
|
89 | 87 | Args: |
90 | 88 | tuning_jobs: Single tuning job or list of tuning jobs (name or HyperparameterTuner object) |
@@ -147,8 +145,7 @@ def create_charts( |
147 | 145 | color_trials: bool = False, |
148 | 146 | advanced: bool = False, |
149 | 147 | ) -> alt.Chart: |
150 | | - """ |
151 | | - Create visualization charts for hyperparameter tuning results. |
| 148 | + """Create visualization charts for hyperparameter tuning results. |
152 | 149 |
|
153 | 150 | Args: |
154 | 151 | trials_df: DataFrame containing trials data |
@@ -240,7 +237,8 @@ def create_charts( |
240 | 237 | # If we have multiple tuning jobs, we also want to be able |
241 | 238 | # to discriminate based on the individual tuning job, so |
242 | 239 | # we just treat them as an additional tuning parameter |
243 | | - tuning_parameters = tuning_parameters.copy() + (["TuningJobName"] if multiple_tuning_jobs else []) |
| 240 | + tuning_job_param = ["TuningJobName"] if multiple_tuning_jobs else [] |
| 241 | + tuning_parameters = tuning_parameters.copy() + tuning_job_param |
244 | 242 |
|
245 | 243 | # If we use early stopping and at least some jobs were |
246 | 244 | # stopped early, we want to be able to discriminate |
@@ -331,7 +329,7 @@ def render_detail_charts(): |
331 | 329 | bandwidth=0.01, |
332 | 330 | groupby=[tuning_parameter], |
333 | 331 | # https://github.com/vega/altair/issues/3203#issuecomment-2141558911 |
334 | | - # Specifying extent no longer necessary (>5.1.2). Leaving the work around in it for now. |
| 332 | + # Extent is not needed for altair >5.1.2; workaround kept for now. |
335 | 333 | extent=[ |
336 | 334 | trials_df[objective_name].min(), |
337 | 335 | trials_df[objective_name].max(), |
@@ -612,7 +610,7 @@ def render_progress_chart(): |
612 | 610 |
|
613 | 611 |
|
614 | 612 | def _clean_parameter_name(s): |
615 | | - """ Helper method to ensure proper parameter name characters for altair 5+ """ |
| 613 | + """Helper function to ensure proper parameter name characters for Altair 5+.""" |
616 | 614 | return s.replace(":", "_").replace(".", "_") |
617 | 615 |
|
618 | 616 |
|
@@ -664,8 +662,10 @@ def _prepare_consolidated_df(trials_df): |
664 | 662 |
|
665 | 663 |
|
666 | 664 | def _get_df(tuning_job_name, filter_out_stopped=False): |
667 | | - """Retrieves hyperparameter tuning job results and returns preprocessed DataFrame with |
668 | | - tuning metrics and parameters.""" |
| 665 | + """Retrieves hyperparameter tuning job results and returns preprocessed DataFrame. |
| 666 | +
|
| 667 | + Returns a DataFrame containing tuning metrics and parameters for the specified job. |
| 668 | + """ |
669 | 669 |
|
670 | 670 | tuner = sagemaker.HyperparameterTuningJobAnalytics(tuning_job_name) |
671 | 671 |
|
@@ -707,10 +707,12 @@ def _get_df(tuning_job_name, filter_out_stopped=False): |
707 | 707 | # A float then? |
708 | 708 | df[parameter_name] = df[parameter_name].astype(float) |
709 | 709 |
|
710 | | - except Exception: |
711 | | - # Trouble, as this was not a number just pretending to be a string, but an actual string with |
712 | | - # characters. Leaving the value untouched |
713 | | - # Ex: Caught exception could not convert string to float: 'sqrt' <class 'ValueError'> |
| 710 | + except (ValueError, TypeError, AttributeError): |
| 711 | + # Catch exceptions that might occur during string manipulation or type conversion |
| 712 | + # - ValueError: Could not convert string to float/int |
| 713 | + # - TypeError: Object doesn't support the operation |
| 714 | + # - AttributeError: Object doesn't have replace method |
| 715 | + # Leaving the value untouched |
714 | 716 | pass |
715 | 717 |
|
716 | 718 | return df |
@@ -747,7 +749,7 @@ def get_job_analytics_data(tuning_job_names): |
747 | 749 | tuning_job_names (str or list): Single tuning job name or list of names/tuner objects. |
748 | 750 |
|
749 | 751 | Returns: |
750 | | - tuple: (DataFrame with training results, tuned parameters list, objective name, is_minimize flag). |
| 752 | + tuple: (DataFrame with training results, tuned params list, objective name, is_minimize). |
751 | 753 |
|
752 | 754 | Raises: |
753 | 755 | ValueError: If tuning jobs have different objectives or optimization directions. |
@@ -828,16 +830,18 @@ def get_job_analytics_data(tuning_job_names): |
828 | 830 | if isinstance(val, str) and val.startswith('"'): |
829 | 831 | try: |
830 | 832 | df[column_name] = df[column_name].apply(lambda x: int(x.replace('"', ""))) |
831 | | - except: # noqa: E722 nosec b110 if we fail, we just continue with what we had |
| 833 | + except (ValueError, TypeError, AttributeError): |
| 834 | + # Best effort: if the conversion fails, we just continue with what we had. |
832 | 835 | pass # Value is not an int, but a string |
833 | 836 |
|
834 | 837 | df = df.sort_values("FinalObjectiveValue", ascending=is_minimize) |
835 | 838 | df[objective_name] = df.pop("FinalObjectiveValue") |
836 | 839 |
|
837 | 840 | # Fix potential issue with dates represented as objects, instead of a timestamp |
838 | 841 | # This can in other cases lead to: |
839 | | - # https://www.markhneedham.com/blog/2020/01/10/altair-typeerror-object-type-date-not-json-serializable/ |
840 | | - # Have only observed this for TrainingEndTime, but will be on the lookout dfor TrainingStartTime as well now |
| 842 | + # https://www.markhneedham.com/blog/2020/01/10/altair-typeerror-object-type-date-not-json-serializable/ |
| 844 | + # Seen this for TrainingEndTime, but will watch TrainingStartTime as well now. |
841 | 845 | df["TrainingEndTime"] = pd.to_datetime(df["TrainingEndTime"]) |
842 | 846 | df["TrainingStartTime"] = pd.to_datetime(df["TrainingStartTime"]) |
843 | 847 |
|
|
0 commit comments