feat(pt): add descriptor name & parameter numbers output & gpu name (only for cuda) & Capitalise some infos (all backends) #5141
Closed
Changes from all commits (6 commits):

d6fa9cb  feat(pt): add descriptor name and parameter numbers output  (OutisLi)
bc341f7  add device name display  (OutisLi)
50f90b4  fix: correct device display only for cuda  (OutisLi)
d9c3db7  feat: Capitalise some infos  (OutisLi)
cc85284  Update deepmd/utils/summary.py  (OutisLi)
218cdc9  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot])
deepmd/pt/train/training.py

@@ -721,6 +721,51 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:

        self.profiling = training_params.get("profiling", False)
        self.profiling_file = training_params.get("profiling_file", "timeline.json")

        # Log model summary info (descriptor type and parameter count)
        if self.rank == 0:
            self._log_model_summary()

    def _log_model_summary(self) -> None:
        """Log model summary information including descriptor type and parameter count."""

        def get_descriptor_type(model: Any) -> str:
            """Get the descriptor type name from the model."""
            # Standard models have a get_descriptor method
            if hasattr(model, "get_descriptor"):
                descriptor = model.get_descriptor()
                serialized = descriptor.serialize()
                if isinstance(serialized, dict) and "type" in serialized:
                    return serialized["type"].upper()
            # ZBL models: the descriptor lives in atomic_model.models[0]
            if hasattr(model, "atomic_model") and hasattr(model.atomic_model, "models"):
                models = model.atomic_model.models
Member comment on the line above: It's not good behavior to access an inner attribute like this. (A public-accessor alternative is sketched after this diff.)
                if models:  # Check non-empty
                    dp_model = models[0]
                    if hasattr(dp_model, "descriptor"):
                        serialized = dp_model.descriptor.serialize()
                        if isinstance(serialized, dict) and "type" in serialized:
                            return serialized["type"].upper() + " (with ZBL)"
            return "UNKNOWN"

        def count_parameters(model: Any) -> int:
            """Count the total number of trainable parameters."""
            return sum(p.numel() for p in model.parameters() if p.requires_grad)

        if not self.multi_task:
            desc_type = get_descriptor_type(self.model)
            num_params = count_parameters(self.model)
            log.info("")
            log.info(f"Descriptor: {desc_type}")
            log.info(f"Model Params: {num_params / 1e6:.3f} M")
        else:
            # For multi-task, log each model's info
            for model_key in self.model_keys:
                desc_type = get_descriptor_type(self.model[model_key])
                num_params = count_parameters(self.model[model_key])
                log.info("")
                log.info(f"Descriptor [{model_key}]: {desc_type}")
                log.info(f"Model Params [{model_key}]: {num_params / 1e6:.3f} M")

    def run(self) -> None:
        fout = (
            open(
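The Member comment above argues against reaching into `atomic_model.models` from the trainer. A minimal sketch of the alternative it implies, using hypothetical class and method names (this is not the actual DeePMD-kit API): give the ZBL wrapper a public `get_descriptor()` so both model kinds share one lookup path.

```python
# Hypothetical sketch, not the DeePMD-kit API: expose the descriptor through
# a public accessor so the trainer never touches the wrapper's internals.
from typing import Any, Optional


class ZBLModelWrapper:  # hypothetical stand-in for the ZBL pairwise wrapper
    def __init__(self, dp_model: Any, zbl_model: Any) -> None:
        self._models = [dp_model, zbl_model]

    def get_descriptor(self) -> Optional[Any]:
        """Return the DP sub-model's descriptor, or None if absent."""
        dp_model = self._models[0] if self._models else None
        return getattr(dp_model, "descriptor", None)


def get_descriptor_type(model: Any) -> str:
    """With the accessor in place, the trainer-side probe has one branch."""
    if hasattr(model, "get_descriptor"):
        descriptor = model.get_descriptor()
        if descriptor is not None:
            serialized = descriptor.serialize()
            if isinstance(serialized, dict) and "type" in serialized:
                return serialized["type"].upper()
    return "UNKNOWN"
```

Either way, the net effect of the feature is a pair of startup log lines such as `Descriptor: SE_E2_A` and `Model Params: 1.234 M`.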
New test file:

@@ -0,0 +1,144 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Tests for model summary display functions."""

import unittest
from unittest.mock import (
    MagicMock,
)

import torch


class TestGetDescriptorType(unittest.TestCase):
    """Test the get_descriptor_type helper function."""

    @staticmethod
    def get_descriptor_type(model):
        """Replicate the logic from training.py for testing."""
        # Standard models have a get_descriptor method
        if hasattr(model, "get_descriptor"):
            descriptor = model.get_descriptor()
            serialized = descriptor.serialize()
            if isinstance(serialized, dict) and "type" in serialized:
                return serialized["type"].upper()
        # ZBL models: the descriptor lives in atomic_model.models[0]
        if hasattr(model, "atomic_model") and hasattr(model.atomic_model, "models"):
            models = model.atomic_model.models
            if models:  # Check non-empty
                dp_model = models[0]
                if hasattr(dp_model, "descriptor"):
                    serialized = dp_model.descriptor.serialize()
                    if isinstance(serialized, dict) and "type" in serialized:
                        return serialized["type"].upper() + " (with ZBL)"
        return "UNKNOWN"

    def test_standard_model(self):
        """Test descriptor type detection for standard models."""
        mock_descriptor = MagicMock()
        mock_descriptor.serialize.return_value = {"type": "se_e2_a"}

        mock_model = MagicMock()
        mock_model.get_descriptor.return_value = mock_descriptor

        result = self.get_descriptor_type(mock_model)
        self.assertEqual(result, "SE_E2_A")

    def test_zbl_model(self):
        """Test descriptor type detection for ZBL models."""
        mock_descriptor = MagicMock()
        mock_descriptor.serialize.return_value = {"type": "dpa1"}

        mock_dp_model = MagicMock()
        mock_dp_model.descriptor = mock_descriptor

        mock_atomic_model = MagicMock()
        mock_atomic_model.models = [mock_dp_model]

        mock_model = MagicMock(spec=[])  # No get_descriptor
        mock_model.atomic_model = mock_atomic_model

        result = self.get_descriptor_type(mock_model)
        self.assertEqual(result, "DPA1 (with ZBL)")

    def test_empty_models_list(self):
        """Test handling of an empty models list in a ZBL model."""
        mock_atomic_model = MagicMock()
        mock_atomic_model.models = []

        mock_model = MagicMock(spec=[])
        mock_model.atomic_model = mock_atomic_model

        result = self.get_descriptor_type(mock_model)
        self.assertEqual(result, "UNKNOWN")

    def test_missing_type_key(self):
        """Test handling of serialize() output without a 'type' key."""
        mock_descriptor = MagicMock()
        mock_descriptor.serialize.return_value = {"other_key": "value"}

        mock_model = MagicMock()
        mock_model.get_descriptor.return_value = mock_descriptor

        result = self.get_descriptor_type(mock_model)
        self.assertEqual(result, "UNKNOWN")

    def test_serialize_returns_non_dict(self):
        """Test handling of serialize() returning a non-dict."""
        mock_descriptor = MagicMock()
        mock_descriptor.serialize.return_value = "not_a_dict"

        mock_model = MagicMock()
        mock_model.get_descriptor.return_value = mock_descriptor

        result = self.get_descriptor_type(mock_model)
        self.assertEqual(result, "UNKNOWN")

    def test_unknown_model_structure(self):
        """Test handling of an unknown model structure."""
        mock_model = MagicMock(spec=[])  # No get_descriptor, no atomic_model
        result = self.get_descriptor_type(mock_model)
        self.assertEqual(result, "UNKNOWN")


class TestCountParameters(unittest.TestCase):
    """Test the count_parameters helper function."""

    @staticmethod
    def count_parameters(model):
        """Replicate the logic from training.py for testing."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    def test_all_trainable(self):
        """Test counting when all parameters are trainable."""
        with torch.device("cpu"):
            model = torch.nn.Linear(10, 5)  # 10*5 + 5 = 55 parameters
        result = self.count_parameters(model)
        self.assertEqual(result, 55)

    def test_mixed_trainable(self):
        """Test counting with some frozen parameters."""
        with torch.device("cpu"):
            model = torch.nn.Sequential(
                torch.nn.Linear(10, 5),  # 55 params
                torch.nn.Linear(5, 3),  # 18 params
            )
        # Freeze the first layer
        for param in model[0].parameters():
            param.requires_grad = False

        result = self.count_parameters(model)
        self.assertEqual(result, 18)  # Only the second layer

    def test_all_frozen(self):
        """Test counting when all parameters are frozen."""
        with torch.device("cpu"):
            model = torch.nn.Linear(10, 5)
        for param in model.parameters():
            param.requires_grad = False

        result = self.count_parameters(model)
        self.assertEqual(result, 0)


if __name__ == "__main__":
    unittest.main()
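For reference, the expected counts in these tests come straight from `nn.Linear`'s parameter shapes: `Linear(in_f, out_f)` holds an `(out_f, in_f)` weight plus an `(out_f,)` bias, i.e. `in_f * out_f + out_f` parameters. A quick standalone check (plain PyTorch, nothing DeePMD-specific):

```python
import torch

# Linear(10, 5): weight 5x10 = 50, bias 5 -> 55 parameters total.
layer = torch.nn.Linear(10, 5)
print([tuple(p.shape) for p in layer.parameters()])  # [(5, 10), (5,)]
print(sum(p.numel() for p in layer.parameters()))    # 55

# Linear(5, 3): weight 3x5 = 15, bias 3 -> 18, matching test_mixed_trainable.
print(sum(p.numel() for p in torch.nn.Linear(5, 3).parameters()))  # 18
```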
Review comment (on the `model: Any` annotations): Why annotate `Any`?

Reply: since … returns `Any`.
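If a narrower annotation were wanted, one option is `torch.nn.Module`, since `parameters()` is all `count_parameters` relies on. A minimal sketch of that alternative (an editorial suggestion under that assumption, not the PR's code; the concrete model wrapper classes may need a different bound):

```python
import torch


def count_parameters(model: torch.nn.Module) -> int:
    """Count trainable parameters; only nn.Module.parameters() is required."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
```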