Skip to content

Commit 80f9df8

Browse files
committed
Add num_samples to the markdown table printed by make_results_table()
1 parent dbaad99 commit 80f9df8

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/lighteval/utils/utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,24 +158,29 @@ def flatten(item: list[Union[list, str]]) -> list[str]:
158158
def make_results_table(result_dict):
    """Generate a markdown table of evaluation results.

    Args:
        result_dict: Mapping with a ``"results"`` entry (task name ->
            {metric name -> value}) and, optionally, ``"versions"`` and
            ``"num_samples"`` entries keyed by task name. Metric keys
            ending in ``"_stderr"`` do not get their own row; they fill
            the "Stderr" column of the metric they belong to.

    Returns:
        The table as produced by ``MarkdownTableWriter.dumps()``.
    """
    md_writer = MarkdownTableWriter()
    md_writer.headers = ["Task", "Version", "Number of Samples", "Metric", "Value", "", "Stderr"]

    # For backwards compatibility, fall back to empty mappings when the
    # optional per-task dicts are missing (or None) in result_dict.
    versions = result_dict.get("versions") or {}
    num_samples_dict = result_dict.get("num_samples") or {}

    values = []
    for task in sorted(result_dict["results"]):
        metrics = result_dict["results"][task]

        # Task-level cells are only shown on the first row of each task;
        # they are blanked after the first emitted row.
        label = task
        version = versions.get(task, "")
        num_samples = num_samples_dict.get(task, "")

        for m, v in metrics.items():
            if m.endswith("_stderr"):
                continue

            stderr_key = m + "_stderr"
            if stderr_key in metrics:
                se = metrics[stderr_key]
                values.append([label, version, num_samples, m, "%.4f" % v, "±", "%.4f" % se])
            else:
                values.append([label, version, num_samples, m, "%.4f" % v, "", ""])

            label = ""
            version = ""
            num_samples = ""

    md_writer.value_matrix = values

    return md_writer.dumps()

0 commit comments

Comments
 (0)