@@ -106,9 +106,70 @@ def test_apply_mask(source, ekd_from_source, mask_name, rename, threshold_option
106106 assert np .sum (np .isnan (result )) == expected_mask_count
107107
108108
109- if __name__ == "__main__" :
110- """Run all test functions that start with 'test_'."""
111- for name , obj in list (globals ().items ()):
112- if name .startswith ("test_" ) and callable (obj ):
113- print (f"Running { name } ..." )
114- obj ()
109+ @pytest .mark .parametrize (
110+ "threshold_options" ,
111+ [
112+ {"mask_value" : 0.5 },
113+ {"mask_value" : 1 },
114+ {"threshold" : 0.5 , "threshold_operator" : ">" },
115+ {"threshold" : 0.5 , "threshold_operator" : "<" },
116+ ],
117+ )
118+ @pytest .mark .parametrize ("rename" , [None , "renamed" ])
119+ @pytest .mark .parametrize ("mask_name" , MASK_VALUES .keys ())
120+ @pytest .mark .parametrize ("target_param" , DATA_VALUES .keys ())
121+ def test_apply_mask_to_param (source , ekd_from_source , target_param , mask_name , rename , threshold_options ):
122+ apply_mask = filter_registry .create (
123+ "apply_mask_to_param" , param = target_param , path = mask_name , rename = rename , ** threshold_options
124+ )
125+ ekd_from_source .assert_called_once_with ("file" , mask_name )
126+
127+ pipeline = source | apply_mask
128+
129+ input_fields = collect_fields_by_param (source )
130+ output_fields = collect_fields_by_param (pipeline )
131+
132+ expected_mask = MASK_VALUES [mask_name ].copy ().flatten ()
133+ if "mask_value" in threshold_options :
134+ expected_mask = expected_mask == threshold_options ["mask_value" ]
135+ else :
136+ operator = {"<" : np .less , ">" : np .greater }[threshold_options ["threshold_operator" ]]
137+ expected_mask = operator (expected_mask , threshold_options ["threshold" ])
138+ expected_mask_count = np .sum (expected_mask )
139+
140+ result_param = f"{ target_param } _{ rename } " if rename else target_param
141+
142+ # The target param should be masked (and possibly renamed)
143+ assert result_param in output_fields
144+ for input_field , output_field in zip (input_fields [target_param ], output_fields [result_param ]):
145+ expected_values = input_field .to_numpy (flatten = True ).copy ()
146+ expected_values [expected_mask ] = np .nan
147+ result = output_field .to_numpy (flatten = True )
148+ np .array_equal (expected_values , result , equal_nan = True )
149+ assert np .sum (np .isnan (result )) == expected_mask_count
150+
151+ # When renamed, the original param name must not appear in the output
152+ if rename :
153+ assert target_param not in output_fields
154+
155+ # All other params should pass through unchanged
156+ for other_param in DATA_VALUES .keys ():
157+ if other_param == target_param :
158+ continue
159+ assert other_param in output_fields
160+ for input_field , output_field in zip (input_fields [other_param ], output_fields [other_param ]):
161+ np .testing .assert_array_equal (
162+ input_field .to_numpy (flatten = True ),
163+ output_field .to_numpy (flatten = True ),
164+ )
165+
166+
167+ def test_apply_mask_to_param_fails_without_param (ekd_from_source ):
168+ with pytest .raises ((ValueError , TypeError )):
169+ filter_registry .create ("apply_mask_to_param" , path = "all_zeros" , mask_value = 1 )
170+
171+
172+ def test_apply_mask_to_param_fails_without_mask_options (ekd_from_source ):
173+ with pytest .raises (ValueError ):
174+ filter_registry .create ("apply_mask_to_param" , param = "t" , path = "all_zeros" )
175+ ekd_from_source .assert_called_once_with ("file" , "all_zeros" )
0 commit comments