@@ -184,6 +184,62 @@ def _after_setup(self, **kwargs) -> None:
184184 self ._setup_properties ()
185185 super ()._after_setup (** kwargs )
186186
def _process_input_for_prediction(
    self,
    smiles_list: list[str],
    model_hparams: Optional[dict] = None,
) -> list:
    """
    Turn raw SMILES strings into the record dicts used for prediction.

    Args:
        smiles_list: SMILES strings to process; one record is produced per entry.
        model_hparams: Not read by this implementation. It is accepted so that
            every `_process_input_for_prediction` variant can be called with
            the same arguments.

    Returns:
        A list with one dict per SMILES (``DataFrame.to_dict("records")``).
        The "features" entry of each dict is the result of
        `_merge_props_into_base` for that row, additionally passed through
        `self.transform` when a transform is configured. An empty input
        yields an empty list.
    """
    # The DataFrame pipeline below expects at least one record, so an empty
    # request is answered directly instead of failing further down.
    if len(smiles_list) == 0:
        return []

    data_df = self._process_smiles_and_props(smiles_list)
    # Row-wise: fold the per-property columns into the base "features" object.
    data_df["features"] = data_df.apply(self._merge_props_into_base, axis=1)

    # apply transformation, e.g. masking for pretraining task
    if self.transform is not None:
        data_df["features"] = data_df["features"].apply(self.transform)

    return data_df.to_dict("records")
202+
def _process_smiles_and_props(self, smiles_list: list[str]) -> pd.DataFrame:
    """
    Process SMILES strings and compute molecular properties.

    Each SMILES is passed through `self.reader.to_data` (with the synthetic
    id ``smiles_<position>`` and no labels). For every entry of
    `self.properties` the raw value is then read with
    `self.reader.read_property` and encoded with the property's encoder.

    Args:
        smiles_list: SMILES strings to process.

    Returns:
        A DataFrame with one row per SMILES, holding the reader's columns plus
        one column per property (named after ``property.name``). A property
        cell is ``None`` when the reader returned no value for that molecule.
        For an empty `smiles_list` an empty DataFrame is returned.
    """
    if len(smiles_list) == 0:
        # Without rows there is no "ident" column to merge on.
        return pd.DataFrame()

    data = [
        self.reader.to_data(
            {"id": f"smiles_{idx}", "features": smiles, "labels": None}
        )
        for idx, smiles in enumerate(smiles_list)
    ]
    # Per the original note: each element is a dict whose features object
    # (GeomData) has only edge_index filled; node and edge features are empty.
    # NOTE(review): the code below relies on the records exposing an "ident"
    # field - confirm against the reader.
    data_df = pd.DataFrame(data)

    # Put every encoder into evaluation mode once, before any encoding happens.
    for prop in self.properties:
        prop.encoder.eval = True

    props: list[dict] = []
    # Rows of data_df are in the same order as smiles_list.
    for ident, smiles in zip(data_df["ident"], smiles_list):
        row_prop_dict: dict = {}
        for prop in self.properties:
            property_value = self.reader.read_property(smiles, prop)
            if property_value is None or len(property_value) == 0:
                encoded_value = None
            else:
                encoded_value = torch.stack(
                    [prop.encoder.encode(v) for v in property_value]
                )
                # Drop the leading axis of a 3-d stack (only if it has size 1).
                if encoded_value.dim() == 3:
                    encoded_value = encoded_value.squeeze(0)
            row_prop_dict[prop.name] = encoded_value
        row_prop_dict["ident"] = ident
        props.append(row_prop_dict)

    # Built from record dicts so each encoded tensor stays a single cell.
    property_df = pd.DataFrame(props)
    return data_df.merge(property_df, on="ident", how="left")
242+
187243
188244class GraphPropertiesMixIn (DataPropertiesSetter , ABC ):
189245 def __init__ (
@@ -361,40 +417,6 @@ def load_processed_data(
361417
362418 return base_df [base_data [0 ].keys ()].to_dict ("records" )
363419
def _process_input_for_prediction(
    self,
    smiles_list: list[str],
    model_hparams: Optional[dict] = None,
) -> list:
    """
    Turn raw SMILES strings into the record dicts used for prediction.

    Each SMILES is passed through `self.reader.to_data`, every property in
    `self.properties` is read and encoded for it, and the encoded properties
    are folded into the record's "features" via `_merge_props_into_base`.

    Args:
        smiles_list: SMILES strings to process; one record is produced per entry.
        model_hparams: Not read by this implementation; accepted (optional) so
            the method can be called like the other
            `_process_input_for_prediction` variants.

    Returns:
        A list with one dict per SMILES (``DataFrame.to_dict("records")``);
        an empty list for an empty input.
    """
    if len(smiles_list) == 0:
        return []

    data = [
        self.reader.to_data(
            {"id": f"smiles_{idx}", "features": smiles, "labels": None}
        )
        for idx, smiles in enumerate(smiles_list)
    ]
    # Per the original note: each element is a dict whose features object
    # (GeomData) has only edge_index filled; node and edge features are empty.
    data_df = pd.DataFrame(data)

    for prop in self.properties:
        prop.encoder.eval = True

    # Collect the encoded properties as one record per molecule. Rows yielded
    # by itertuples are read-only tuples, so values are gathered in plain
    # dicts and attached to data_df as columns afterwards.
    prop_records: list[dict] = []
    for row_idx, smiles in enumerate(smiles_list):
        record: dict = {"_row": row_idx}
        for prop in self.properties:
            property_value = self.reader.read_property(smiles, prop)
            if property_value is None or len(property_value) == 0:
                encoded_value = None
            else:
                encoded_value = torch.stack(
                    [prop.encoder.encode(v) for v in property_value]
                )
                # Drop the leading axis of a 3-d stack (only if it has size 1).
                if encoded_value.dim() == 3:
                    encoded_value = encoded_value.squeeze(0)
            record[prop.name] = encoded_value
        prop_records.append(record)

    # Both frames share the default positional index, so an index join lines
    # the property columns up with their molecules.
    property_df = pd.DataFrame(prop_records).drop(columns="_row")
    data_df = data_df.join(property_df)

    data_df["features"] = data_df.apply(self._merge_props_into_base, axis=1)

    return data_df.to_dict("records")
397-
398420
399421class GraphPropAsPerNodeType (DataPropertiesSetter , ABC ):
400422 def __init__ (self , properties = None , transform = None , ** kwargs ):
@@ -605,6 +627,36 @@ def _merge_props_into_base(
605627 is_graph_node = is_graph_node ,
606628 )
607629
def _process_input_for_prediction(
    self,
    smiles_list: list[str],
    model_hparams: Optional[dict] = None,
) -> list:
    """
    Turn raw SMILES strings into the record dicts used for prediction.

    The model's ``in_channels`` hyperparameter is forwarded to
    `_merge_props_into_base` as the maximum length of the node properties.

    Args:
        smiles_list: SMILES strings to process; one record is produced per entry.
        model_hparams: Model hyperparameters. Must contain a "config" dict
            with a non-None "in_channels" entry.

    Returns:
        A list with one dict per SMILES (``DataFrame.to_dict("records")``).
        The "features" entry is the result of `_merge_props_into_base`,
        additionally passed through `self.transform` when one is configured.
        An empty input yields an empty list.

    Raises:
        ValueError: If `model_hparams` is missing, has no "config" dictionary,
            or its "in_channels" entry is absent or None.
    """
    # Look the value up defensively so that a missing "config" key produces
    # the descriptive ValueError below rather than a bare KeyError.
    config = (model_hparams or {}).get("config") or {}
    in_channels = config.get("in_channels")
    if in_channels is None:
        raise ValueError(
            f"model_hparams must be provided for data class: {self.__class__.__name__}"
            f" which should contain 'in_channels' key with valid value in 'config' dictionary."
        )

    # max_len_node_properties is determined by the model's in_channels.
    max_len_node_properties = int(in_channels)

    # The DataFrame pipeline below expects at least one record.
    if len(smiles_list) == 0:
        return []

    data_df = self._process_smiles_and_props(smiles_list)
    data_df["features"] = data_df.apply(
        self._merge_props_into_base,
        axis=1,
        args=(max_len_node_properties,),
    )

    # apply transformation, e.g. masking for pretraining task
    if self.transform is not None:
        data_df["features"] = data_df["features"].apply(self.transform)

    return data_df.to_dict("records")
659+
608660
609661class ChEBI50_StaticGNI (DataPropertiesSetter , ChEBIOver50 ):
610662 READER = RandomFeatureInitializationReader
0 commit comments