1616import tempfile
1717import os
1818import datetime
19+ from math import nan , inf
20+ import numpy as np
1921from smexperiments import api_types , tracker , trial_component , _utils , _environment
2022import pandas as pd
2123
@@ -171,6 +173,11 @@ def test_log_parameter(under_test):
171173 assert under_test .trial_component .parameters ["whizz" ] == 1
172174
173175
176+ def test_log_parameter_skip_invalid_value (under_test ):
177+ under_test .log_parameter ("key" , nan )
178+ assert "key" not in under_test .trial_component .parameters
179+
180+
174181def test_enter (under_test ):
175182 under_test .__enter__ ()
176183 assert isinstance (under_test .trial_component .start_time , datetime .datetime )
@@ -213,6 +220,11 @@ def test_log_parameters(under_test):
213220 assert under_test .trial_component .parameters == {"a" : "b" , "c" : "d" , "e" : 5 }
214221
215222
223+ def test_log_parameters_skip_invalid_values (under_test ):
224+ under_test .log_parameters ({"a" : "b" , "c" : "d" , "e" : 5 , "f" : nan })
225+ assert under_test .trial_component .parameters == {"a" : "b" , "c" : "d" , "e" : 5 }
226+
227+
216228def test_log_input (under_test ):
217229 under_test .log_input ("foo" , "baz" , "text/text" )
218230 assert under_test .trial_component .input_artifacts == {
@@ -233,6 +245,11 @@ def test_log_metric(under_test):
233245 under_test ._metrics_writer .log_metric .assert_called_with ("foo" , 1.0 , 1 , now )
234246
235247
248+ def test_log_metric_skip_invalid_value (under_test ):
249+ under_test .log_metric (None , nan , None , None )
250+ assert not under_test ._metrics_writer .log_metric .called
251+
252+
236253def test_log_metric_attribute_error (under_test ):
237254 now = datetime .datetime .now ()
238255
@@ -630,3 +647,19 @@ def test_log_roc_curve(under_test):
630647 )
631648
632649 under_test ._lineage_artifact_tracker .add_input_artifact ("TestROCCurve" , "s3uri_value" , "etag_value" , "ROCCurve" )
650+
651+
652+ @pytest .mark .parametrize (
653+ "metric_value" ,
654+ [1.3 , "nan" , "inf" , "-inf" , None ],
655+ )
656+ def test_is_input_valid (under_test , metric_value ):
657+ assert under_test ._is_input_valid ("metric" , "Name" , metric_value )
658+
659+
660+ @pytest .mark .parametrize (
661+ "metric_value" ,
662+ [nan , inf , - inf ],
663+ )
664+ def test__is_input_valid_false (under_test , metric_value ):
665+ assert not under_test ._is_input_valid ("parameter" , "Name" , metric_value )
0 commit comments