Skip to content

Commit d5df045

Browse files
authored
skip invalid numbers when logging a metric or parameter (#161)
1 parent dea989b commit d5df045

File tree

2 files changed

+49
-3
lines changed

2 files changed

+49
-3
lines changed

src/smexperiments/tracker.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import logging
2020
import botocore
2121
import json
22+
from math import isnan, isinf
23+
from numbers import Number
2224
from smexperiments._utils import get_module
2325
from os.path import join
2426

@@ -231,7 +233,8 @@ def log_parameter(self, name, value):
231233
name (str): The name of the parameter
232234
value (str or numbers.Number): The value of the parameter
233235
"""
234-
self.trial_component.parameters[name] = value
236+
if self._is_input_valid("parameter", name, value):
237+
self.trial_component.parameters[name] = value
235238

236239
def log_parameters(self, parameters):
237240
"""Record a collection of parameter values for this trial component.
@@ -245,7 +248,10 @@ def log_parameters(self, parameters):
245248
Args:
246249
parameters (dict[str, str or numbers.Number]): The parameters to record.
247250
"""
248-
self.trial_component.parameters.update(parameters)
251+
filtered_parameters = {
252+
key: value for (key, value) in parameters.items() if self._is_input_valid("parameter", key, value)
253+
}
254+
self.trial_component.parameters.update(filtered_parameters)
249255

250256
def log_input(self, name, value, media_type=None):
251257
"""Record a single input artifact for this trial component.
@@ -402,7 +408,8 @@ def log_metric(self, metric_name, value, timestamp=None, iteration_number=None):
402408
AttributeError: If the metrics writer is not initialized.
403409
"""
404410
try:
405-
self._metrics_writer.log_metric(metric_name, value, timestamp, iteration_number)
411+
if self._is_input_valid("metric", metric_name, value):
412+
self._metrics_writer.log_metric(metric_name, value, timestamp, iteration_number)
406413
except AttributeError:
407414
if not self._metrics_writer:
408415
if not self._warned_on_metrics:
@@ -654,6 +661,12 @@ def _log_graph_artifact(self, name, data, graph_type, output_artifact):
654661
else:
655662
self._lineage_artifact_tracker.add_input_artifact(artifact_name, s3_uri, etag, graph_type)
656663

664+
def _is_input_valid(self, input_type, field_name, field_value):
665+
if isinstance(field_value, Number) and (isnan(field_value) or isinf(field_value)):
666+
logging.warning(f"Failed to log {input_type} {field_name}. Received invalid value: {field_value}.")
667+
return False
668+
return True
669+
657670
def __enter__(self):
658671
"""Updates the start time of the tracked trial component.
659672

tests/unit/test_tracker.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import tempfile
1717
import os
1818
import datetime
19+
from math import nan, inf
20+
import numpy as np
1921
from smexperiments import api_types, tracker, trial_component, _utils, _environment
2022
import pandas as pd
2123

@@ -171,6 +173,11 @@ def test_log_parameter(under_test):
171173
assert under_test.trial_component.parameters["whizz"] == 1
172174

173175

176+
def test_log_parameter_skip_invalid_value(under_test):
177+
under_test.log_parameter("key", nan)
178+
assert "key" not in under_test.trial_component.parameters
179+
180+
174181
def test_enter(under_test):
175182
under_test.__enter__()
176183
assert isinstance(under_test.trial_component.start_time, datetime.datetime)
@@ -213,6 +220,11 @@ def test_log_parameters(under_test):
213220
assert under_test.trial_component.parameters == {"a": "b", "c": "d", "e": 5}
214221

215222

223+
def test_log_parameters_skip_invalid_values(under_test):
224+
under_test.log_parameters({"a": "b", "c": "d", "e": 5, "f": nan})
225+
assert under_test.trial_component.parameters == {"a": "b", "c": "d", "e": 5}
226+
227+
216228
def test_log_input(under_test):
217229
under_test.log_input("foo", "baz", "text/text")
218230
assert under_test.trial_component.input_artifacts == {
@@ -233,6 +245,11 @@ def test_log_metric(under_test):
233245
under_test._metrics_writer.log_metric.assert_called_with("foo", 1.0, 1, now)
234246

235247

248+
def test_log_metric_skip_invalid_value(under_test):
249+
under_test.log_metric(None, nan, None, None)
250+
assert not under_test._metrics_writer.log_metric.called
251+
252+
236253
def test_log_metric_attribute_error(under_test):
237254
now = datetime.datetime.now()
238255

@@ -630,3 +647,19 @@ def test_log_roc_curve(under_test):
630647
)
631648

632649
under_test._lineage_artifact_tracker.add_input_artifact("TestROCCurve", "s3uri_value", "etag_value", "ROCCurve")
650+
651+
652+
@pytest.mark.parametrize(
653+
"metric_value",
654+
[1.3, "nan", "inf", "-inf", None],
655+
)
656+
def test_is_input_valid(under_test, metric_value):
657+
assert under_test._is_input_valid("metric", "Name", metric_value)
658+
659+
660+
@pytest.mark.parametrize(
661+
"metric_value",
662+
[nan, inf, -inf],
663+
)
664+
def test__is_input_valid_false(under_test, metric_value):
665+
assert not under_test._is_input_valid("parameter", "Name", metric_value)

0 commit comments

Comments
 (0)