Skip to content

Commit 173788c

Browse files
committed
fix pred pipe func
1 parent a3284b6 commit 173788c

File tree

1 file changed

+86
-34
lines changed
  • chebai_graph/preprocessing/datasets

1 file changed

+86
-34
lines changed

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 86 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,62 @@ def _after_setup(self, **kwargs) -> None:
184184
self._setup_properties()
185185
super()._after_setup(**kwargs)
186186

187+
def _process_input_for_prediction(
188+
self,
189+
smiles_list: list[str],
190+
model_hparams: Optional[dict] = None,
191+
) -> list:
192+
data_df = self._process_smiles_and_props(smiles_list)
193+
data_df["features"] = data_df.apply(
194+
lambda row: self._merge_props_into_base(row), axis=1
195+
)
196+
197+
# apply transformation, e.g. masking for pretraining task
198+
if self.transform is not None:
199+
data_df["features"] = data_df["features"].apply(self.transform)
200+
201+
return data_df.to_dict("records")
202+
203+
def _process_smiles_and_props(self, smiles_list: list[str]) -> pd.DataFrame:
204+
"""
205+
Process SMILES strings and compute molecular properties.
206+
"""
207+
data = [
208+
self.reader.to_data(
209+
{"id": f"smiles_{idx}", "features": smiles, "labels": None}
210+
)
211+
for idx, smiles in enumerate(smiles_list)
212+
]
213+
# element of data is a dict with 'id' and 'features' (GeomData)
214+
# GeomData has only edge_index filled but node and edges features are empty.
215+
216+
assert len(data) == len(smiles_list), "Data length mismatch."
217+
data_df = pd.DataFrame(data)
218+
219+
props: list[dict] = []
220+
for data_row in data_df.itertuples(index=True):
221+
row_prop_dict: dict = {}
222+
for property in self.properties:
223+
property.encoder.eval = True
224+
property_value = self.reader.read_property(
225+
smiles_list[data_row.Index], property
226+
)
227+
if property_value is None or len(property_value) == 0:
228+
encoded_value = None
229+
else:
230+
encoded_value = torch.stack(
231+
[property.encoder.encode(v) for v in property_value]
232+
)
233+
if len(encoded_value.shape) == 3:
234+
encoded_value = encoded_value.squeeze(0)
235+
row_prop_dict[property.name] = encoded_value
236+
row_prop_dict["ident"] = data_row.ident
237+
props.append(row_prop_dict)
238+
239+
property_df = pd.DataFrame(props)
240+
data_df = data_df.merge(property_df, on="ident", how="left")
241+
return data_df
242+
187243

188244
class GraphPropertiesMixIn(DataPropertiesSetter, ABC):
189245
def __init__(
@@ -361,40 +417,6 @@ def load_processed_data(
361417

362418
return base_df[base_data[0].keys()].to_dict("records")
363419

364-
def _process_input_for_prediction(self, smiles_list: list[str]) -> list:
365-
data = [
366-
self.reader.to_data(
367-
{"id": f"smiles_{idx}", "features": smiles, "labels": None}
368-
)
369-
for idx, smiles in enumerate(smiles_list)
370-
]
371-
# element of data is a dict with 'id' and 'features' (GeomData)
372-
# GeomData has only edge_index filled but node and edges features are empty.
373-
374-
assert len(data) == len(smiles_list), "Data length mismatch."
375-
data_df = pd.DataFrame(data)
376-
377-
for idx, data_row in data_df.itertuples(index=True):
378-
property_data = data_row
379-
for property in self.properties:
380-
property.encoder.eval = True
381-
property_value = self.reader.read_property(smiles_list[idx], property)
382-
if property_value is None or len(property_value) == 0:
383-
encoded_value = None
384-
else:
385-
encoded_value = torch.stack(
386-
[property.encoder.encode(v) for v in property_value]
387-
)
388-
if len(encoded_value.shape) == 3:
389-
encoded_value = encoded_value.squeeze(0)
390-
property_data[property.name] = encoded_value
391-
392-
property_data["features"] = property_data.apply(
393-
lambda row: self._merge_props_into_base(row), axis=1
394-
)
395-
396-
return data_df.to_dict("records")
397-
398420

399421
class GraphPropAsPerNodeType(DataPropertiesSetter, ABC):
400422
def __init__(self, properties=None, transform=None, **kwargs):
@@ -605,6 +627,36 @@ def _merge_props_into_base(
605627
is_graph_node=is_graph_node,
606628
)
607629

630+
def _process_input_for_prediction(
631+
self,
632+
smiles_list: list[str],
633+
model_hparams: Optional[dict] = None,
634+
) -> list:
635+
if (
636+
model_hparams is None
637+
or "in_channels" not in model_hparams["config"]
638+
or model_hparams["config"]["in_channels"] is None
639+
):
640+
raise ValueError(
641+
f"model_hparams must be provided for data class: {self.__class__.__name__}"
642+
f" which should contain 'in_channels' key with valid value in 'config' dictionary."
643+
)
644+
645+
max_len_node_properties = int(model_hparams["config"]["in_channels"])
646+
# Determine max_len_node_properties based on in_channels
647+
648+
data_df = self._process_smiles_and_props(smiles_list)
649+
data_df["features"] = data_df.apply(
650+
lambda row: self._merge_props_into_base(row, max_len_node_properties),
651+
axis=1,
652+
)
653+
654+
# apply transformation, e.g. masking for pretraining task
655+
if self.transform is not None:
656+
data_df["features"] = data_df["features"].apply(self.transform)
657+
658+
return data_df.to_dict("records")
659+
608660

609661
class ChEBI50_StaticGNI(DataPropertiesSetter, ChEBIOver50):
610662
READER = RandomFeatureInitializationReader

0 commit comments

Comments
 (0)