Skip to content

Commit 7482772

Browse files
authored
adding sample std if non bernouilli rewards (#56)
* adding sample std if non bernouilli rewards * switching to ddof=1
1 parent adfcb8c commit 7482772

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

src/agentlab/analyze/inspect_results.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,9 +242,23 @@ def get_std_err(df, metric):
242242
data = df[metric].dropna().values
243243

244244
# asser either 0 or 1
245-
assert np.all(np.isin(data, [0, 1]))
245+
if np.all(np.isin(data, [0, 1])):
246+
mean = np.mean(data)
247+
std_err = np.sqrt(mean * (1 - mean) / len(data))
248+
else:
249+
return get_sample_std_err(df, metric)
250+
return mean, std_err
251+
252+
253+
def get_sample_std_err(df, metric):
254+
"""Get the standard error for a binary metric."""
255+
# extract non missing values
256+
data = df[metric].dropna().values
257+
246258
mean = np.mean(data)
247-
std_err = np.sqrt(mean * (1 - mean) / len(data))
259+
std_err = np.std(data, ddof=1) / np.sqrt(len(data))
260+
if np.isnan(std_err):
261+
std_err = 0
248262
return mean, std_err
249263

250264

0 commit comments

Comments
 (0)