Skip to content

Commit 2cd5a5e

Browse files
fix + tests + doc files
1 parent f13a5f3 commit 2cd5a5e

20 files changed

+862
-350
lines changed

docs/source/_rst/_code.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,18 @@ Blocks
122122
Continuous Convolution Block <model/block/convolution.rst>
123123
Orthogonal Block <model/block/orthogonal.rst>
124124

125+
Message Passing
126+
-------------------
127+
128+
.. toctree::
129+
:titlesonly:
130+
131+
Deep Tensor Network Block <model/block/message_passing/deep_tensor_network_block.rst>
132+
E(n) Equivariant Network Block <model/block/message_passing/en_equivariant_network_block.rst>
133+
Interaction Network Block <model/block/message_passing/interaction_network_block.rst>
134+
Radial Field Network Block <model/block/message_passing/radial_field_network_block.rst>
135+
Schnet Block <model/block/message_passing/schnet_block.rst>
136+
125137

126138
Reduction and Embeddings
127139
--------------------------
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
Deep Tensor Network Block
2+
==================================
3+
.. currentmodule:: pina.model.block.message_passing.deep_tensor_network_block
4+
5+
.. autoclass:: DeepTensorNetworkBlock
6+
:members:
7+
:show-inheritance:
8+
:noindex:
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
E(n) Equivariant Network Block
2+
==================================
3+
.. currentmodule:: pina.model.block.message_passing.en_equivariant_network_block
4+
5+
.. autoclass:: EnEquivariantNetworkBlock
6+
:members:
7+
:show-inheritance:
8+
:noindex:
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
Interaction Network Block
2+
==================================
3+
.. currentmodule:: pina.model.block.message_passing.interaction_network_block
4+
5+
.. autoclass:: InteractionNetworkBlock
6+
:members:
7+
:show-inheritance:
8+
:noindex:
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
Radial Field Network Block
2+
==================================
3+
.. currentmodule:: pina.model.block.message_passing.radial_field_network_block
4+
5+
.. autoclass:: RadialFieldNetworkBlock
6+
:members:
7+
:show-inheritance:
8+
:noindex:
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
Schnet Block
2+
==================================
3+
.. currentmodule:: pina.model.block.message_passing.schnet_block
4+
5+
.. autoclass:: SchnetBlock
6+
:members:
7+
:show-inheritance:
8+
:noindex:

pina/model/block/message_passing/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33
__all__ = [
44
"InteractionNetworkBlock",
55
"DeepTensorNetworkBlock",
6+
"EnEquivariantNetworkBlock",
7+
"RadialFieldNetworkBlock",
8+
"SchnetBlock",
69
]
710

811
from .interaction_network_block import InteractionNetworkBlock
912
from .deep_tensor_network_block import DeepTensorNetworkBlock
13+
from .en_equivariant_network_block import EnEquivariantNetworkBlock
14+
from .radial_field_network_block import RadialFieldNetworkBlock
15+
from .schnet_block import SchnetBlock

pina/model/block/message_passing/deep_tensor_network_block.py

Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,17 @@
22

33
import torch
44
from torch_geometric.nn import MessagePassing
5-
from ....utils import check_consistency
5+
from ....utils import check_positive_integer
66

77

88
class DeepTensorNetworkBlock(MessagePassing):
99
"""
1010
Implementation of the Deep Tensor Network block.
1111
1212
This block is used to perform message-passing between nodes and edges in a
13-
graph neural network, following the scheme proposed by Schutt et al. (2017).
14-
It serves as an inner block in a larger graph neural network architecture.
13+
graph neural network, following the scheme proposed by Schutt et al. in
14+
2017. It serves as an inner block in a larger graph neural network
15+
architecture.
1516
1617
The message between two nodes connected by an edge is computed by applying a
1718
linear transformation to the sender node features and the edge features,
@@ -24,9 +25,9 @@ class DeepTensorNetworkBlock(MessagePassing):
2425
.. seealso::
2526
2627
**Original reference**: Schutt, K., Arbabzadah, F., Chmiela, S. et al.
27-
*Quantum-Chemical Insights from Deep Tensor Neural Networks*.
28+
(2017). *Quantum-Chemical Insights from Deep Tensor Neural Networks*.
2829
Nature Communications 8, 13890 (2017).
29-
DOI: `<https://doi.org/10.1038/ncomms13890>_`.
30+
DOI: `<https://doi.org/10.1038/ncomms13890>`_.
3031
"""
3132

3233
def __init__(
@@ -39,7 +40,7 @@ def __init__(
3940
flow="source_to_target",
4041
):
4142
"""
42-
Initialization of the :class:`DeepTensorNetworkBlocklock` class.
43+
Initialization of the :class:`DeepTensorNetworkBlock` class.
4344
4445
:param int node_feature_dim: The dimension of the node features.
4546
:param int edge_feature_dim: The dimension of the edge features.
@@ -57,51 +58,36 @@ def __init__(
5758
flow means that messages are sent from the target node to the
5859
source node. See :class:`torch_geometric.nn.MessagePassing` for more
5960
details. Default is "source_to_target".
60-
:raises ValueError: If `node_feature_dim` is not a positive integer.
61-
:raises ValueError: If `edge_feature_dim` is not a positive integer.
61+
:raises AssertionError: If `node_feature_dim` is not a positive integer.
62+
:raises AssertionError: If `edge_feature_dim` is not a positive integer.
6263
"""
6364
super().__init__(aggr=aggr, node_dim=node_dim, flow=flow)
6465

65-
# Check consistency
66-
check_consistency(node_feature_dim, int)
67-
check_consistency(edge_feature_dim, int)
68-
6966
# Check values
70-
if node_feature_dim <= 0:
71-
raise ValueError(
72-
"`node_feature_dim` must be a positive integer,"
73-
f" got {node_feature_dim}."
74-
)
75-
76-
if edge_feature_dim <= 0:
77-
raise ValueError(
78-
"`edge_feature_dim` must be a positive integer,"
79-
f" got {edge_feature_dim}."
80-
)
81-
82-
# Initialize parameters
83-
self.node_feature_dim = node_feature_dim
84-
self.edge_feature_dim = edge_feature_dim
85-
self.activation = activation
67+
check_positive_integer(node_feature_dim, strict=True)
68+
check_positive_integer(edge_feature_dim, strict=True)
69+
70+
# Activation function
71+
self.activation = activation()
8672

8773
# Layer for processing node features
8874
self.node_layer = torch.nn.Linear(
89-
in_features=self.node_feature_dim,
90-
out_features=self.node_feature_dim,
75+
in_features=node_feature_dim,
76+
out_features=node_feature_dim,
9177
bias=True,
9278
)
9379

9480
# Layer for processing edge features
9581
self.edge_layer = torch.nn.Linear(
96-
in_features=self.edge_feature_dim,
97-
out_features=self.node_feature_dim,
82+
in_features=edge_feature_dim,
83+
out_features=node_feature_dim,
9884
bias=True,
9985
)
10086

10187
# Layer for computing the message
10288
self.message_layer = torch.nn.Linear(
103-
in_features=self.node_feature_dim,
104-
out_features=self.node_feature_dim,
89+
in_features=node_feature_dim,
90+
out_features=node_feature_dim,
10591
bias=False,
10692
)
10793

pina/model/block/message_passing/egnn_block.py

Lines changed: 0 additions & 137 deletions
This file was deleted.

0 commit comments

Comments
 (0)