Skip to content

Commit d2064d2

Browse files
committed
Create calculate_num_samples method in evaluation_tracker to count number of samples per task
1 parent 88bd36a commit d2064d2

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

src/lighteval/logging/evaluation_tracker.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2121
# SOFTWARE.
2222

23+
import collections
2324
import json
2425
import logging
2526
import os
@@ -724,3 +725,29 @@ def push_to_tensorboard( # noqa: C901
724725
f"Pushed to tensorboard at https://huggingface.co/{self.tensorboard_repo}/{output_dir_tb}/tensorboard"
725726
f" at global_step {global_step}"
726727
)
728+
729+
def calculate_num_samples(self) -> dict[str, int]:
    """Count the number of evaluation samples per task, including grouped tasks.

    Mirrors the grouping logic of ``MetricsLogger.aggregate()`` so the
    synthetic sub-group keys (``"<suite>|<task>:_average|<fewshot>"``)
    match up between sample counts and aggregated metrics.

    Returns:
        dict[str, int]: mapping from task name to sample count, plus one
        ``:_average`` entry per multi-subtask group and an ``"all"`` total
        summed over every other entry (individual tasks and group averages).
    """
    # Count samples of individual tasks.
    num_samples = {task: len(samples) for task, samples in self.details_logger.details.items()}

    # Group subtasks (e.g. "suite|mmlu:abstract_algebra|0") under a shared
    # "suite|mmlu:_average|0" key. We must append the *full* original task
    # name — not the rebound middle component from the tuple unpack — so
    # that the num_samples lookup below can find it (the original code
    # shadowed the loop variable and raised KeyError for multi-subtask groups).
    grouped_tasks = collections.defaultdict(list)
    for task_name in num_samples:
        if "|" in task_name:
            suite, task, fewshot = task_name.split("|")
            grouped_tasks[f"{suite}|{task.split(':')[0]}:_average|{fewshot}"].append(task_name)

    # Only materialize an average entry when the group has more than one subtask.
    for average_task, list_of_subtasks in grouped_tasks.items():
        if len(list_of_subtasks) > 1:
            num_samples[average_task] = sum(num_samples[k] for k in list_of_subtasks)

    # Add sample count for all (sums individual tasks and group averages;
    # excludes a pre-existing "all" key to stay idempotent across calls).
    num_samples["all"] = sum(count for task, count in num_samples.items() if task != "all")

    return num_samples

0 commit comments

Comments
 (0)