Skip to content

Commit a0d6ea7

Browse files
committed
revert mol props
1 parent 102a71c commit a0d6ea7

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

chebai_graph/preprocessing/reader/static_gni.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@ def __init__(
1717
self,
1818
num_node_properties: int,
1919
num_bond_properties: int,
20-
# num_molecule_properties: int,
20+
num_molecule_properties: int,
2121
distribution: str = "normal",
2222
*args,
2323
**kwargs,
2424
):
2525
super().__init__(*args, **kwargs)
2626
self.num_node_properties = num_node_properties
2727
self.num_bond_properties = num_bond_properties
28-
# self.num_molecule_properties = num_molecule_properties
28+
self.num_molecule_properties = num_molecule_properties
2929
assert distribution in ["normal", "uniform", "xavier_normal", "xavier_uniform"]
3030
self.distribution = distribution
3131

@@ -44,30 +44,30 @@ def _read_data(self, raw_data):
4444
random_edge_attr = torch.empty(
4545
data.edge_index.shape[1], self.num_bond_properties
4646
)
47-
# random_molecule_properties = torch.empty(1, self.num_molecule_properties)
47+
random_molecule_properties = torch.empty(1, self.num_molecule_properties)
4848

4949
if self.distribution == "normal":
5050
torch.nn.init.normal_(random_x)
5151
torch.nn.init.normal_(random_edge_attr)
52-
# torch.nn.init.normal_(random_molecule_properties)
52+
torch.nn.init.normal_(random_molecule_properties)
5353
elif self.distribution == "uniform":
5454
torch.nn.init.uniform_(random_x, a=-1.0, b=1.0)
5555
torch.nn.init.uniform_(random_edge_attr, a=-1.0, b=1.0)
56-
# torch.nn.init.uniform_(random_molecule_properties, a=-1.0, b=1.0)
56+
torch.nn.init.uniform_(random_molecule_properties, a=-1.0, b=1.0)
5757
elif self.distribution == "xavier_normal":
5858
torch.nn.init.xavier_normal_(random_x)
5959
torch.nn.init.xavier_normal_(random_edge_attr)
60-
# torch.nn.init.xavier_normal_(random_molecule_properties)
60+
torch.nn.init.xavier_normal_(random_molecule_properties)
6161
elif self.distribution == "xavier_uniform":
6262
torch.nn.init.xavier_uniform_(random_x)
6363
torch.nn.init.xavier_uniform_(random_edge_attr)
64-
# torch.nn.init.xavier_uniform_(random_molecule_properties)
64+
torch.nn.init.xavier_uniform_(random_molecule_properties)
6565
else:
6666
raise ValueError("Unknown distribution type")
6767

6868
data.x = random_x
6969
data.edge_attr = random_edge_attr
70-
# data.molecule_attr = random_molecule_properties
70+
data.molecule_attr = random_molecule_properties
7171
return data
7272

7373
def read_property(self, *args, **kwargs) -> Exception:

configs/data/chebi50_static_gni.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ init_args:
33
reader_kwargs:
44
num_node_properties: 158
55
num_bond_properties: 7
6+
num_molecule_properties: 0

0 commit comments

Comments
 (0)