41 changes: 41 additions & 0 deletions deepmd/pt/train/training.py
@@ -761,6 +761,47 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
        self.profiling = training_params.get("profiling", False)
        self.profiling_file = training_params.get("profiling_file", "timeline.json")

        # Log model parameter count
        if self.rank == 0:
            self._log_parameter_count()

    @staticmethod
    def _count_parameters(model: torch.nn.Module) -> tuple[int, int]:
        """
        Count model parameters.

        Parameters
        ----------
        model : torch.nn.Module
            The model to count parameters for.

        Returns
        -------
        tuple[int, int]
            A tuple of (trainable, total) parameter counts.
        """
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        return trainable, total

    def _log_parameter_count(self) -> None:
        """Log model parameter count."""
        if not self.multi_task:
            trainable, total = self._count_parameters(self.model)
            log.info(
                f"Model Params: {total / 1e6:.3f} M (Trainable: {trainable / 1e6:.3f} M)"
            )
        else:
            log.warning(
                "In multitask mode, parameters may be shared across tasks. "
                "The following per-task counts may include duplicates."
            )
            for model_key in self.model_keys:
                trainable, total = self._count_parameters(self.model[model_key])
                log.info(
                    f"Model Params [{model_key}]: {total / 1e6:.3f} M (Trainable: {trainable / 1e6:.3f} M)"
                )

    def run(self) -> None:
        fout = (
            open(
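For reference, a minimal standalone sketch (not part of this diff, using a hypothetical toy model) of how the counting expressions in `_count_parameters` distinguish trainable from total parameters when part of a model is frozen:

# Illustrative sketch only: the same numel()/requires_grad logic as in the PR,
# applied to a small torch.nn.Sequential with one frozen layer.
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(4, 8),  # 4*8 weights + 8 biases = 40 parameters
    torch.nn.Linear(8, 2),  # 8*2 weights + 2 biases = 18 parameters
)
for p in model[0].parameters():
    p.requires_grad = False  # freeze the first layer

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(trainable, total)  # 18 58

The frozen layer still contributes to the total count but drops out of the trainable count, which is the distinction the "Model Params ... (Trainable: ...)" log line reports.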