@@ -169,6 +169,29 @@ def std_err_diff_baselines(rewards, baselines):
169169 return adjusted_reward_mean , adjusted_se
170170
171171
172+ def _clean_input (rewards , baselines ):
173+ rewards = np .asarray (rewards )
174+ baselines = np .asarray (baselines )
175+ baselines = _replace_nans_by_average (baselines )
176+ if rewards .shape [0 ] != baselines .shape [0 ]:
177+ raise ValueError ("rewards and baselines must have the same length." )
178+ if rewards .ndim != 1 :
179+ raise ValueError ("rewards must be a 1D array." )
180+ if baselines .ndim != 2 :
181+ raise ValueError ("baselines must be a 2D array." )
182+
183+ # remove nan rows
184+ valid = ~ np .isnan (rewards )
185+ rewards = rewards [valid ]
186+ baselines = baselines [valid ]
187+ if rewards .size == 0 :
188+ raise ValueError ("No valid rewards after filtering." )
189+ if baselines .shape [0 ] != rewards .shape [0 ]:
190+ raise ValueError ("rewards and baselines must have the same length after filtering." )
191+
192+ return rewards , baselines
193+
194+
172195def std_err_ancova (rewards , baselines ):
173196 """
174197 Parameters:
@@ -183,9 +206,7 @@ def std_err_ancova(rewards, baselines):
183206 - standard_error: float
184207 Standard error of the adjusted mean
185208 """
186- # Convert inputs to numpy arrays
187- rewards = np .asarray (rewards )
188- baselines = np .asarray (baselines )
209+ rewards , baselines = _clean_input (rewards , baselines )
189210
190211 # Center the baselines
191212 baseline_means = baselines .mean (axis = 0 )
0 commit comments