Skip to content

Commit 93b31dc

Browse files
committed
Add input validation function for rewards and baselines
1 parent 847e095 commit 93b31dc

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

src/agentlab/analyze/covariate_std_err.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,29 @@ def std_err_diff_baselines(rewards, baselines):
169169
return adjusted_reward_mean, adjusted_se
170170

171171

172+
def _clean_input(rewards, baselines):
173+
rewards = np.asarray(rewards)
174+
baselines = np.asarray(baselines)
175+
baselines = _replace_nans_by_average(baselines)
176+
if rewards.shape[0] != baselines.shape[0]:
177+
raise ValueError("rewards and baselines must have the same length.")
178+
if rewards.ndim != 1:
179+
raise ValueError("rewards must be a 1D array.")
180+
if baselines.ndim != 2:
181+
raise ValueError("baselines must be a 2D array.")
182+
183+
# remove nan rows
184+
valid = ~np.isnan(rewards)
185+
rewards = rewards[valid]
186+
baselines = baselines[valid]
187+
if rewards.size == 0:
188+
raise ValueError("No valid rewards after filtering.")
189+
if baselines.shape[0] != rewards.shape[0]:
190+
raise ValueError("rewards and baselines must have the same length after filtering.")
191+
192+
return rewards, baselines
193+
194+
172195
def std_err_ancova(rewards, baselines):
173196
"""
174197
Parameters:
@@ -183,9 +206,7 @@ def std_err_ancova(rewards, baselines):
183206
- standard_error: float
184207
Standard error of the adjusted mean
185208
"""
186-
# Convert inputs to numpy arrays
187-
rewards = np.asarray(rewards)
188-
baselines = np.asarray(baselines)
209+
rewards, baselines = _clean_input(rewards, baselines)
189210

190211
# Center the baselines
191212
baseline_means = baselines.mean(axis=0)

0 commit comments

Comments
 (0)