File tree Expand file tree Collapse file tree 1 file changed +27
-0
lines changed Expand file tree Collapse file tree 1 file changed +27
-0
lines changed Original file line number Diff line number Diff line change 2020# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2121# SOFTWARE.
2222
23+ import collections
2324import json
2425import logging
2526import os
@@ -724,3 +725,29 @@ def push_to_tensorboard( # noqa: C901
724725 f"Pushed to tensorboard at https://huggingface.co/{ self .tensorboard_repo } /{ output_dir_tb } /tensorboard"
725726 f" at global_step { global_step } "
726727 )
728+
def calculate_num_samples(self) -> dict[str, int]:
    """Count the number of samples per task, including grouped tasks.

    Task keys of the form ``suite|task[:subtask]|fewshot`` are grouped under
    ``suite|task:_average|fewshot``. A group entry is only emitted when it
    contains more than one subtask. The grouping is intended to follow
    MetricsLogger.aggregate() so that the sub-groups of tasks match up
    (that method is not visible here; keep the two in sync).

    Returns:
        A dict mapping each individual task key to its sample count, each
        multi-subtask group key to the summed count of its subtasks, and
        ``"all"`` to the total number of samples over the individual tasks.

    Raises:
        ValueError: If a key containing ``"|"`` does not split into exactly
            three ``"|"``-separated parts.
    """
    # Count samples of individual tasks.
    num_samples = {task: len(samples) for task, samples in self.details_logger.details.items()}

    # The grand total is taken over individual tasks only, *before* group
    # entries are added, so samples of grouped tasks are not counted twice.
    total_samples = sum(count for task, count in num_samples.items() if task != "all")

    # Collect the full task keys belonging to each ":_average" group.
    grouped_tasks = collections.defaultdict(list)
    for task_key in num_samples:
        if "|" in task_key:
            suite, task_name, fewshot = task_key.split("|")
            group_key = f"{suite}|{task_name.split(':')[0]}:_average|{fewshot}"
            # Store the full key (not the bare task name) so the lookup into
            # num_samples below succeeds.
            grouped_tasks[group_key].append(task_key)

    # Only groups with more than one subtask get an aggregated entry.
    for average_task, subtask_keys in grouped_tasks.items():
        if len(subtask_keys) > 1:
            num_samples[average_task] = sum(num_samples[key] for key in subtask_keys)

    # Add sample count for all.
    num_samples["all"] = total_samples

    return num_samples
You can’t perform that action at this time.
0 commit comments