diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index 0dfbe94b6b..520a628756 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -761,6 +761,47 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
         self.profiling = training_params.get("profiling", False)
         self.profiling_file = training_params.get("profiling_file", "timeline.json")
 
+        # Log model parameter count
+        if self.rank == 0:
+            self._log_parameter_count()
+
+    @staticmethod
+    def _count_parameters(model: torch.nn.Module) -> tuple[int, int]:
+        """
+        Count model parameters.
+
+        Parameters
+        ----------
+        model : torch.nn.Module
+            The model to count parameters for.
+
+        Returns
+        -------
+        tuple[int, int]
+            A tuple of (trainable, total) parameter counts.
+        """
+        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        total = sum(p.numel() for p in model.parameters())
+        return trainable, total
+
+    def _log_parameter_count(self) -> None:
+        """Log model parameter count."""
+        if not self.multi_task:
+            trainable, total = self._count_parameters(self.model)
+            log.info(
+                f"Model Params: {total / 1e6:.3f} M (Trainable: {trainable / 1e6:.3f} M)"
+            )
+        else:
+            log.warning(
+                "In multitask mode, parameters may be shared across tasks. "
+                "The following per-task counts may include duplicates."
+            )
+            for model_key in self.model_keys:
+                trainable, total = self._count_parameters(self.model[model_key])
+                log.info(
+                    f"Model Params [{model_key}]: {total / 1e6:.3f} M (Trainable: {trainable / 1e6:.3f} M)"
+                )
+
     def run(self) -> None:
         fout = (
             open(
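
For reference, a minimal standalone sketch of what the counting logic above computes (trainable vs. total parameters via p.numel()). The toy torch.nn.Linear model and the frozen bias below are illustrative assumptions, not part of the patch.

# Standalone sketch of the counting logic in the diff; the toy model and the
# frozen bias are illustrative assumptions, not part of the patch.
import torch


def count_parameters(model: torch.nn.Module) -> tuple[int, int]:
    """Return (trainable, total) parameter counts, mirroring the diff's helper."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total


if __name__ == "__main__":
    model = torch.nn.Linear(10, 5)    # 50 weights + 5 biases = 55 parameters
    model.bias.requires_grad_(False)  # freeze the bias so the two counts differ
    trainable, total = count_parameters(model)
    print(f"trainable={trainable}, total={total}")  # prints: trainable=50, total=55
    # The patch reports counts in millions, e.g. f"{total / 1e6:.3f} M".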