File tree Expand file tree Collapse file tree 1 file changed +27
-0
lines changed Expand file tree Collapse file tree 1 file changed +27
-0
lines changed Original file line number Diff line number Diff line change 20
20
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
21
# SOFTWARE.
22
22
23
+ import collections
23
24
import json
24
25
import logging
25
26
import os
@@ -724,3 +725,29 @@ def push_to_tensorboard( # noqa: C901
724
725
f"Pushed to tensorboard at https://huggingface.co/{ self .tensorboard_repo } /{ output_dir_tb } /tensorboard"
725
726
f" at global_step { global_step } "
726
727
)
728
+
729
def calculate_num_samples(self) -> dict[str, int]:
    """Count the number of samples per task, including grouped tasks.

    The grouping is oriented on ``MetricsLogger.aggregate()`` so that the
    sub-group keys match up: every task key of the form
    ``suite|parent:subtask|fewshot`` is filed under
    ``suite|parent:_average|fewshot``.

    Returns:
        A dict mapping each key of ``self.details_logger.details`` to the
        number of samples stored for it, plus one ``...:_average|...`` entry
        for every group that contains more than one subtask, plus an
        ``"all"`` entry holding the total over the individual tasks.

    Raises:
        ValueError: If a task key contains ``"|"`` but does not split into
            exactly three ``"|"``-separated parts.
    """
    # Count samples of individual tasks.
    num_samples = {name: len(samples) for name, samples in self.details_logger.details.items()}

    # Total over the individual tasks only. It is taken before the group
    # entries are inserted so that grouped samples are not counted twice.
    total = sum(count for name, count in num_samples.items() if name != "all")

    # Collect the full task keys belonging to each sub group.
    grouped_tasks = collections.defaultdict(list)
    for full_name in num_samples:
        if "|" in full_name:
            suite, task, fewshot = full_name.split("|")
            # Store the full key (not the middle segment) because it is used
            # below to look the count up in ``num_samples``.
            grouped_tasks[f"{suite}|{task.split(':')[0]}:_average|{fewshot}"].append(full_name)

    # Only groups with at least two subtasks get an aggregated entry.
    for average_task, list_of_subtasks in grouped_tasks.items():
        if len(list_of_subtasks) > 1:
            num_samples[average_task] = sum(num_samples[k] for k in list_of_subtasks)

    # Add sample count for all.
    num_samples["all"] = total

    return num_samples
You can’t perform that action at this time.
0 commit comments