Skip to content

Commit 28f0d7a

Browse files
committed
merge
1 parent 9ab02e4 commit 28f0d7a

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

deepmd/pt/train/training.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -765,13 +765,29 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
765765
if self.rank == 0:
766766
self._log_parameter_count()
767767

768+
@staticmethod
769+
def _count_parameters(model: torch.nn.Module) -> tuple[int, int]:
770+
"""
771+
Count model parameters.
772+
773+
Parameters
774+
----------
775+
model : torch.nn.Module
776+
The model to count parameters for.
777+
778+
Returns
779+
-------
780+
tuple[int, int]
781+
A tuple of (trainable, total) parameter counts.
782+
"""
783+
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
784+
total = sum(p.numel() for p in model.parameters())
785+
return trainable, total
786+
768787
def _log_parameter_count(self) -> None:
769788
"""Log model parameter count."""
770789
if not self.multi_task:
771-
trainable = sum(
772-
p.numel() for p in self.model.parameters() if p.requires_grad
773-
)
774-
total = sum(p.numel() for p in self.model.parameters())
790+
trainable, total = self._count_parameters(self.model)
775791
log.info(
776792
f"Model Params: {total / 1e6:.3f} M (Trainable: {trainable / 1e6:.3f} M)"
777793
)
@@ -781,11 +797,7 @@ def _log_parameter_count(self) -> None:
781797
"The following per-task counts may include duplicates."
782798
)
783799
for model_key in self.model_keys:
784-
model = self.model[model_key]
785-
trainable = sum(
786-
p.numel() for p in model.parameters() if p.requires_grad
787-
)
788-
total = sum(p.numel() for p in model.parameters())
800+
trainable, total = self._count_parameters(self.model[model_key])
789801
log.info(
790802
f"Model Params [{model_key}]: {total / 1e6:.3f} M (Trainable: {trainable / 1e6:.3f} M)"
791803
)

0 commit comments

Comments
 (0)