Skip to content

Commit c00b65c

Browse files
committed
add class for props as per node type
1 parent 5bb5f0f commit c00b65c

File tree

1 file changed

+154
-4
lines changed
  • chebai_graph/preprocessing/datasets

1 file changed

+154
-4
lines changed

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 154 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,13 @@
1515
from torch_geometric.data.data import Data as GeomData
1616

1717
from chebai_graph.preprocessing.properties import (
18+
AllNodeTypeProperty,
19+
AtomNodeTypeProperty,
1820
AtomProperty,
1921
BondProperty,
22+
FGNodeTypeProperty,
2023
MolecularProperty,
24+
MoleculeProperty,
2125
)
2226
from chebai_graph.preprocessing.reader import (
2327
AtomFGReader_NoFGEdges_WithGraphNode,
@@ -41,7 +45,7 @@ def __init__(self, **kwargs):
4145
super().__init__(**kwargs)
4246

4347

44-
class GraphPropertiesMixIn(ChEBIOverX, ABC):
48+
class DataPropertiesSetter(ChEBIOverX, ABC):
4549
"""Mixin for adding molecular property encodings to graph-based ChEBI datasets."""
4650

4751
READER = GraphPropertyReader
@@ -172,6 +176,8 @@ def _after_setup(self, **kwargs) -> None:
172176
self._setup_properties()
173177
super()._after_setup(**kwargs)
174178

179+
180+
class GraphPropertiesMixIn(DataPropertiesSetter, ABC):
175181
def _merge_props_into_base(self, row: pd.Series) -> GeomData:
176182
"""
177183
Merge encoded molecular properties into the GeomData object.
@@ -208,8 +214,10 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData:
208214
[edge_attr, torch.cat([property_values, property_values], dim=0)],
209215
dim=1,
210216
)
211-
else:
217+
elif isinstance(property, MoleculeProperty):
212218
molecule_attr = torch.cat([molecule_attr, property_values], dim=1)
219+
else:
220+
raise TypeError(f"Unsupported property type: {type(property).__name__}")
213221

214222
return GeomData(
215223
x=x,
@@ -261,11 +269,153 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
261269
f"Finished loading dataset from properties.\nEncoding lengths: {prop_lengths}\n"
262270
f"Use n_atom_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, AtomProperty))}, "
263271
f"n_bond_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, BondProperty))}, "
264-
f"n_molecule_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if not isinstance(p, (AtomProperty, BondProperty)))}"
272+
f"n_molecule_properties: {sum(p.encoder.get_encoding_length() for p in self.properties if isinstance(p, MoleculeProperty))}"
273+
)
274+
275+
return base_df[base_data[0].keys()].to_dict("records")
276+
277+
278+
class GraphPropertiesAsPerNodeType(DataPropertiesSetter, ABC):
279+
READER = AtomFGReader_WithFGEdges_WithGraphNode
280+
281+
def load_processed_data_from_file(self, filename: str) -> list[dict]:
282+
"""
283+
Load dataset and merge cached properties into base features.
284+
285+
Args:
286+
filename: The path to the file to load.
287+
288+
Returns:
289+
List of data entries, each a dictionary.
290+
"""
291+
base_data = super().load_processed_data_from_file(filename)
292+
base_df = pd.DataFrame(base_data)
293+
294+
props_categories = {
295+
"AllNodeTypeProperties": [],
296+
"FGNodeTypeProperties": [],
297+
"AtomNodeTypeProperties": [],
298+
"GraphNodeTypeProperties": [],
299+
"BondProperties": [],
300+
}
301+
n_atom_node_properties, n_fg_node_properties = 0, 0
302+
n_bond_properties, n_graph_node_properties = 0, 0
303+
prop_lengths = []
304+
for prop in self.properties:
305+
prop_length = prop.encoder.get_encoding_length()
306+
prop_name = prop.name
307+
prop_lengths.append((prop_name, prop_length))
308+
if isinstance(prop, AllNodeTypeProperty):
309+
n_atom_node_properties += prop_length
310+
n_fg_node_properties += prop_length
311+
props_categories["AllNodeTypeProperties"].append(prop_name)
312+
elif isinstance(prop, FGNodeTypeProperty):
313+
n_fg_node_properties += prop_length
314+
props_categories["FGNodeTypeProperties"].append(prop_name)
315+
elif isinstance(prop, AtomNodeTypeProperty):
316+
n_atom_node_properties += prop_length
317+
props_categories["AtomNodeTypeProperties"].append(prop_name)
318+
elif isinstance(prop, BondProperty):
319+
n_bond_properties += prop_length
320+
props_categories["BondProperties"].append(prop_name)
321+
elif isinstance(prop, MoleculeProperty):
322+
# molecule props will be used as graph node props
323+
n_graph_node_properties += prop_length
324+
props_categories["GraphNodeTypeProperties"].append(prop_name)
325+
else:
326+
raise TypeError(f"Unsupported property type: {type(prop).__name__}")
327+
328+
n_atom_properties = max(
329+
n_atom_node_properties, n_fg_node_properties, n_graph_node_properties
265330
)
331+
rank_zero_info(
332+
f"Finished loading dataset from properties.\nEncoding lengths: {prop_lengths}\n"
333+
f"Properties Categories {props_categories}\n"
334+
f"n_atom_node_properties: {n_atom_node_properties}, "
335+
f"n_fg_node_properties: {n_fg_node_properties}, "
336+
f"n_bond_properties: {n_bond_properties}, "
337+
f"n_graph_node_properties: {n_graph_node_properties}\n"
338+
f"Use n_atom_properties: {n_atom_properties}, n_bond_properties: {n_bond_properties}, n_molecule_properties: 0"
339+
)
340+
341+
for property in self.properties:
342+
property_data = torch.load(
343+
self.get_property_path(property), weights_only=False
344+
)
345+
if len(property_data[0][property.name].shape) > 1:
346+
property.encoder.set_encoding_length(
347+
property_data[0][property.name].shape[1]
348+
)
349+
350+
property_df = pd.DataFrame(property_data)
351+
property_df.rename(
352+
columns={property.name: f"{property.name}"}, inplace=True
353+
)
354+
base_df = base_df.merge(property_df, on="ident", how="left")
355+
356+
base_df["features"] = base_df.apply(
357+
lambda row: self._merge_props_into_base(row), axis=1
358+
)
359+
360+
# apply transformation, e.g. masking for pretraining task
361+
if self.transform is not None:
362+
base_df["features"] = base_df["features"].apply(self.transform)
266363

267364
return base_df[base_data[0].keys()].to_dict("records")
268365

366+
def _merge_props_into_base(self, row: pd.Series) -> GeomData:
367+
"""
368+
Merge encoded molecular properties into the GeomData object.
369+
370+
Args:
371+
row: A dictionary containing 'features' and encoded properties.
372+
373+
Returns:
374+
A GeomData object with merged features.
375+
"""
376+
geom_data = row["features"]
377+
assert isinstance(geom_data, GeomData)
378+
is_atom_node = geom_data.is_atom_node
379+
assert is_atom_node is not None, "`is_atom_node` must be set in the geom_data"
380+
is_graph_node = geom_data.is_graph_node
381+
assert is_graph_node is not None, "`is_graph_node` must be set in the geom_data"
382+
383+
edge_attr = geom_data.edge_attr
384+
x = geom_data.x
385+
molecule_attr = torch.empty((1, 0))
386+
387+
for property in self.properties:
388+
property_values = row[f"{property.name}"]
389+
if isinstance(property_values, torch.Tensor):
390+
if len(property_values.size()) == 0:
391+
property_values = property_values.unsqueeze(0)
392+
if len(property_values.size()) == 1:
393+
property_values = property_values.unsqueeze(1)
394+
else:
395+
property_values = torch.zeros(
396+
(0, property.encoder.get_encoding_length())
397+
)
398+
399+
if isinstance(property, AtomProperty):
400+
x = torch.cat([x, property_values], dim=1)
401+
elif isinstance(property, BondProperty):
402+
# Concat/Duplicate properties values for undirected graph as `edge_index` has first src to tgt edges, then tgt to src edges
403+
edge_attr = torch.cat(
404+
[edge_attr, torch.cat([property_values, property_values], dim=0)],
405+
dim=1,
406+
)
407+
elif isinstance(property, MoleculeProperty):
408+
molecule_attr = torch.cat([molecule_attr, property_values], dim=1)
409+
else:
410+
raise TypeError(f"Unsupported property type: {type(property).__name__}")
411+
412+
return GeomData(
413+
x=x,
414+
edge_index=geom_data.edge_index,
415+
edge_attr=edge_attr,
416+
molecule_attr=molecule_attr,
417+
)
418+
269419

270420
class ChEBI50GraphProperties(GraphPropertiesMixIn, ChEBIOver50):
271421
"""ChEBIOver50 dataset with molecular property encodings."""
@@ -310,7 +460,7 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData:
310460
data = super()._merge_props_into_base(row)
311461
return self._add_graph_node_mask(data, row)
312462

313-
def _add_graph_node_mask(self, data: GeomData, row) -> GeomData:
463+
def _add_graph_node_mask(self, data: GeomData, row: pd.Series) -> GeomData:
314464
"""
315465
Add a graph node mask to the GeomData object.
316466

0 commit comments

Comments
 (0)