1- from chebifier .prediction_models .nn_predictor import NNPredictor
21import chebai_graph .preprocessing .properties as p
32import torch
43from chebai_graph .models .graph import ResGatedGraphConvNetGraphPred
5- from chebai_graph .preprocessing .reader import GraphPropertyReader
64from chebai_graph .preprocessing .property_encoder import IndexEncoder , OneHotEncoder
5+ from chebai_graph .preprocessing .reader import GraphPropertyReader
76from torch_geometric .data .data import Data as GeomData
87
8+ from chebifier .prediction_models .nn_predictor import NNPredictor
9+
910
1011class ResGatedPredictor (NNPredictor ):
1112
1213 def __init__ (self , model_name : str , ckpt_path : str , molecular_properties , ** kwargs ):
13- super ().__init__ (model_name , ckpt_path , reader_cls = GraphPropertyReader , ** kwargs )
14+ super ().__init__ (
15+ model_name , ckpt_path , reader_cls = GraphPropertyReader , ** kwargs
16+ )
1417 # molecular_properties is a list of class paths
1518 if molecular_properties is not None :
1619 properties = [self .load_class (prop )() for prop in molecular_properties ]
@@ -32,11 +35,23 @@ def load_class(self, class_path: str):
3235
3336 def init_model (self , ckpt_path : str , ** kwargs ) -> ResGatedGraphConvNetGraphPred :
3437 model = ResGatedGraphConvNetGraphPred .load_from_checkpoint (
35- ckpt_path , map_location = torch .device (self .device ), criterion = None , strict = False ,
36- metrics = dict (train = dict (), test = dict (), validation = dict ()), pretrained_checkpoint = None ,
37- config = {"in_length" : 256 , "hidden_length" : 512 , "dropout_rate" : 0.1 , "n_conv_layers" : 3 ,
38- "n_linear_layers" : 3 , "n_atom_properties" : 158 , "n_bond_properties" : 7 ,
39- "n_molecule_properties" : 200 })
38+ ckpt_path ,
39+ map_location = torch .device (self .device ),
40+ criterion = None ,
41+ strict = False ,
42+ metrics = dict (train = dict (), test = dict (), validation = dict ()),
43+ pretrained_checkpoint = None ,
44+ config = {
45+ "in_length" : 256 ,
46+ "hidden_length" : 512 ,
47+ "dropout_rate" : 0.1 ,
48+ "n_conv_layers" : 3 ,
49+ "n_linear_layers" : 3 ,
50+ "n_atom_properties" : 158 ,
51+ "n_bond_properties" : 7 ,
52+ "n_molecule_properties" : 200 ,
53+ },
54+ )
4055 model .eval ()
4156 return model
4257
@@ -55,14 +70,21 @@ def read_smiles(self, smiles):
5570 # use default value if we meet an unseen value
5671 if isinstance (prop .encoder , IndexEncoder ):
5772 if str (value ) in prop .encoder .cache :
58- index = prop .encoder .cache .index (str (value )) + prop .encoder .offset
73+ index = (
74+ prop .encoder .cache .index (str (value )) + prop .encoder .offset
75+ )
5976 else :
6077 index = 0
61- print (f"Unknown property value { value } for property { prop } at smiles { smiles } " )
78+ print (
79+ f"Unknown property value { value } for property { prop } at smiles { smiles } "
80+ )
6281 if isinstance (prop .encoder , OneHotEncoder ):
63- encoded_values .append (torch .nn .functional .one_hot (
64- torch .tensor (index ), num_classes = prop .encoder .get_encoding_length ()
65- ))
82+ encoded_values .append (
83+ torch .nn .functional .one_hot (
84+ torch .tensor (index ),
85+ num_classes = prop .encoder .get_encoding_length (),
86+ )
87+ )
6688 else :
6789 encoded_values .append (torch .tensor ([index ]))
6890
@@ -77,9 +99,7 @@ def read_smiles(self, smiles):
7799 if len (encoded_values .size ()) == 1 :
78100 encoded_values = encoded_values .unsqueeze (1 )
79101 else :
80- encoded_values = torch .zeros (
81- (0 , prop .encoder .get_encoding_length ())
82- )
102+ encoded_values = torch .zeros ((0 , prop .encoder .get_encoding_length ()))
83103 if isinstance (prop , p .AtomProperty ):
84104 x = torch .cat ([x , encoded_values ], dim = 1 )
85105 elif isinstance (prop , p .BondProperty ):
@@ -93,4 +113,4 @@ def read_smiles(self, smiles):
93113 edge_attr = edge_attr ,
94114 molecule_attr = molecule_attr ,
95115 )
96- return d
116+ return d
0 commit comments