@@ -17,15 +17,15 @@ def __init__(
1717 self ,
1818 num_node_properties : int ,
1919 num_bond_properties : int ,
20- # num_molecule_properties: int,
20+ num_molecule_properties : int ,
2121 distribution : str = "normal" ,
2222 * args ,
2323 ** kwargs ,
2424 ):
2525 super ().__init__ (* args , ** kwargs )
2626 self .num_node_properties = num_node_properties
2727 self .num_bond_properties = num_bond_properties
28- # self.num_molecule_properties = num_molecule_properties
28+ self .num_molecule_properties = num_molecule_properties
2929 assert distribution in ["normal" , "uniform" , "xavier_normal" , "xavier_uniform" ]
3030 self .distribution = distribution
3131
@@ -44,30 +44,30 @@ def _read_data(self, raw_data):
4444 random_edge_attr = torch .empty (
4545 data .edge_index .shape [1 ], self .num_bond_properties
4646 )
47- # random_molecule_properties = torch.empty(1, self.num_molecule_properties)
47+ random_molecule_properties = torch .empty (1 , self .num_molecule_properties )
4848
4949 if self .distribution == "normal" :
5050 torch .nn .init .normal_ (random_x )
5151 torch .nn .init .normal_ (random_edge_attr )
52- # torch.nn.init.normal_(random_molecule_properties)
52+ torch .nn .init .normal_ (random_molecule_properties )
5353 elif self .distribution == "uniform" :
5454 torch .nn .init .uniform_ (random_x , a = - 1.0 , b = 1.0 )
5555 torch .nn .init .uniform_ (random_edge_attr , a = - 1.0 , b = 1.0 )
56- # torch.nn.init.uniform_(random_molecule_properties, a=-1.0, b=1.0)
56+ torch .nn .init .uniform_ (random_molecule_properties , a = - 1.0 , b = 1.0 )
5757 elif self .distribution == "xavier_normal" :
5858 torch .nn .init .xavier_normal_ (random_x )
5959 torch .nn .init .xavier_normal_ (random_edge_attr )
60- # torch.nn.init.xavier_normal_(random_molecule_properties)
60+ torch .nn .init .xavier_normal_ (random_molecule_properties )
6161 elif self .distribution == "xavier_uniform" :
6262 torch .nn .init .xavier_uniform_ (random_x )
6363 torch .nn .init .xavier_uniform_ (random_edge_attr )
64- # torch.nn.init.xavier_uniform_(random_molecule_properties)
64+ torch .nn .init .xavier_uniform_ (random_molecule_properties )
6565 else :
6666 raise ValueError ("Unknown distribution type" )
6767
6868 data .x = random_x
6969 data .edge_attr = random_edge_attr
70- # data.molecule_attr = random_molecule_properties
70+ data .molecule_attr = random_molecule_properties
7171 return data
7272
7373 def read_property (self , * args , ** kwargs ) -> Exception :
0 commit comments