@@ -765,13 +765,29 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
         if self.rank == 0:
             self._log_parameter_count()
 
+    @staticmethod
+    def _count_parameters(model: torch.nn.Module) -> tuple[int, int]:
+        """
+        Count model parameters.
+
+        Parameters
+        ----------
+        model : torch.nn.Module
+            The model to count parameters for.
+
+        Returns
+        -------
+        tuple[int, int]
+            A tuple of (trainable, total) parameter counts.
+        """
+        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        total = sum(p.numel() for p in model.parameters())
+        return trainable, total
+
     def _log_parameter_count(self) -> None:
         """Log model parameter count."""
         if not self.multi_task:
-            trainable = sum(
-                p.numel() for p in self.model.parameters() if p.requires_grad
-            )
-            total = sum(p.numel() for p in self.model.parameters())
+            trainable, total = self._count_parameters(self.model)
             log.info(
                 f"Model Params: {total / 1e6:.3f} M (Trainable: {trainable / 1e6:.3f} M)"
             )
@@ -781,11 +797,7 @@ def _log_parameter_count(self) -> None:
                 "The following per-task counts may include duplicates."
             )
             for model_key in self.model_keys:
-                model = self.model[model_key]
-                trainable = sum(
-                    p.numel() for p in model.parameters() if p.requires_grad
-                )
-                total = sum(p.numel() for p in model.parameters())
+                trainable, total = self._count_parameters(self.model[model_key])
                 log.info(
                     f"Model Params [{model_key}]: {total / 1e6:.3f} M (Trainable: {trainable / 1e6:.3f} M)"
                 )
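
For reference, here is a minimal standalone sketch of the extracted helper in action. The module-level `count_parameters` name and the toy `Sequential` model are illustrative only, not part of the patch; they just demonstrate the trainable/total split the new static method computes:

```python
import torch


def count_parameters(model: torch.nn.Module) -> tuple[int, int]:
    """Return (trainable, total) parameter counts, mirroring _count_parameters."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


# Toy model with one frozen layer so trainable < total.
model = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.Linear(20, 1))
for p in model[0].parameters():
    p.requires_grad = False  # freeze the first layer

trainable, total = count_parameters(model)
print(f"Model Params: {total / 1e6:.3f} M (Trainable: {trainable / 1e6:.3f} M)")
# This tiny model has 241 total and 21 trainable parameters, so both
# values round to 0.000 M with the :.3f format used in the log message.
```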