@@ -17,15 +17,15 @@ def __init__(
1717 self ,
1818 num_node_properties : int ,
1919 num_bond_properties : int ,
20- num_molecule_properties : int ,
21- distribution : str ,
20+ # num_molecule_properties: int,
21+ distribution : str = "normal" ,
2222 * args ,
2323 ** kwargs ,
2424 ):
2525 super ().__init__ (* args , ** kwargs )
2626 self .num_node_properties = num_node_properties
2727 self .num_bond_properties = num_bond_properties
28- self .num_molecule_properties = num_molecule_properties
28+ # self.num_molecule_properties = num_molecule_properties
2929 assert distribution in ["normal" , "uniform" , "xavier_normal" , "xavier_uniform" ]
3030 self .distribution = distribution
3131
@@ -44,30 +44,30 @@ def _read_data(self, raw_data):
4444 random_edge_attr = torch .empty (
4545 data .edge_index .shape [1 ], self .num_bond_properties
4646 )
47- random_molecule_properties = torch .empty (1 , self .num_molecule_properties )
47+ # random_molecule_properties = torch.empty(1, self.num_molecule_properties)
4848
4949 if self .distribution == "normal" :
5050 torch .nn .init .normal_ (random_x )
5151 torch .nn .init .normal_ (random_edge_attr )
52- torch .nn .init .normal_ (random_molecule_properties )
52+ # torch.nn.init.normal_(random_molecule_properties)
5353 elif self .distribution == "uniform" :
5454 torch .nn .init .uniform_ (random_x , a = - 1.0 , b = 1.0 )
5555 torch .nn .init .uniform_ (random_edge_attr , a = - 1.0 , b = 1.0 )
56- torch .nn .init .uniform_ (random_molecule_properties , a = - 1.0 , b = 1.0 )
56+ # torch.nn.init.uniform_(random_molecule_properties, a=-1.0, b=1.0)
5757 elif self .distribution == "xavier_normal" :
5858 torch .nn .init .xavier_normal_ (random_x )
5959 torch .nn .init .xavier_normal_ (random_edge_attr )
60- torch .nn .init .xavier_normal_ (random_molecule_properties )
60+ # torch.nn.init.xavier_normal_(random_molecule_properties)
6161 elif self .distribution == "xavier_uniform" :
6262 torch .nn .init .xavier_uniform_ (random_x )
6363 torch .nn .init .xavier_uniform_ (random_edge_attr )
64- torch .nn .init .xavier_uniform_ (random_molecule_properties )
64+ # torch.nn.init.xavier_uniform_(random_molecule_properties)
6565 else :
6666 raise ValueError ("Unknown distribution type" )
6767
6868 data .x = random_x
6969 data .edge_attr = random_edge_attr
70- data .molecule_attr = random_molecule_properties
70+ # data.molecule_attr = random_molecule_properties
7171 return data
7272
7373 def read_property (self , * args , ** kwargs ) -> Exception :
0 commit comments