Skip to content

Commit e4e5397

Browse files
committed
modify merge method to have diff props for diff nodes
1 parent 54b0d07 commit e4e5397

File tree

1 file changed

+70
-11
lines changed
  • chebai_graph/preprocessing/datasets

1 file changed

+70
-11
lines changed

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 70 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,19 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
275275
return base_df[base_data[0].keys()].to_dict("records")
276276

277277

278-
class GraphPropertiesAsPerNodeType(DataPropertiesSetter, ABC):
279-
READER = AtomFGReader_WithFGEdges_WithGraphNode
278+
class GraphPropAsPerNodeType(DataPropertiesSetter, ABC):
279+
def __init__(self, properties=None, transform=None, **kwargs):
280+
super().__init__(properties, transform, **kwargs)
281+
# Sort properties so that AllNodeTypeProperty instances come first, rest of the properties order remain same
282+
first = [
283+
prop for prop in self.properties if isinstance(prop, AllNodeTypeProperty)
284+
]
285+
rest = [
286+
prop
287+
for prop in self.properties
288+
if not isinstance(prop, AllNodeTypeProperty)
289+
]
290+
self.properties = first + rest
280291

281292
def load_processed_data_from_file(self, filename: str) -> list[dict]:
282293
"""
@@ -308,6 +319,7 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
308319
if isinstance(prop, AllNodeTypeProperty):
309320
n_atom_node_properties += prop_length
310321
n_fg_node_properties += prop_length
322+
n_graph_node_properties += prop_length
311323
props_categories["AllNodeTypeProperties"].append(prop_name)
312324
elif isinstance(prop, FGNodeTypeProperty):
313325
n_fg_node_properties += prop_length
@@ -354,7 +366,11 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
354366
base_df = base_df.merge(property_df, on="ident", how="left")
355367

356368
base_df["features"] = base_df.apply(
357-
lambda row: self._merge_props_into_base(row), axis=1
369+
lambda row: self._merge_props_into_base(
370+
row,
371+
max_len_node_properties=n_atom_properties,
372+
),
373+
axis=1,
358374
)
359375

360376
# apply transformation, e.g. masking for pretraining task
@@ -363,7 +379,9 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
363379

364380
return base_df[base_data[0].keys()].to_dict("records")
365381

366-
def _merge_props_into_base(self, row: pd.Series) -> GeomData:
382+
def _merge_props_into_base(
383+
self, row: pd.Series, max_len_node_properties: int
384+
) -> GeomData:
367385
"""
368386
Merge encoded molecular properties into the GeomData object.
369387
@@ -375,14 +393,24 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData:
375393
"""
376394
geom_data = row["features"]
377395
assert isinstance(geom_data, GeomData)
396+
378397
is_atom_node = geom_data.is_atom_node
379398
assert is_atom_node is not None, "`is_atom_node` must be set in the geom_data"
380399
is_graph_node = geom_data.is_graph_node
381400
assert is_graph_node is not None, "`is_graph_node` must be set in the geom_data"
382401

402+
is_fg_node = ~is_atom_node & ~is_graph_node
403+
num_nodes = geom_data.x.size(0)
383404
edge_attr = geom_data.edge_attr
384-
x = geom_data.x
385-
molecule_attr = torch.empty((1, 0))
405+
406+
# Initialize node feature matrix
407+
assert (
408+
max_len_node_properties is not None
409+
), "Maximum len of node properties should not be None"
410+
x = torch.zeros((num_nodes, max_len_node_properties))
411+
412+
# Track column offsets for each node type
413+
atom_offset, fg_offset, graph_offset = 0, 0, 0
386414

387415
for property in self.properties:
388416
property_values = row[f"{property.name}"]
@@ -396,24 +424,51 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData:
396424
(0, property.encoder.get_encoding_length())
397425
)
398426

399-
if isinstance(property, AtomProperty):
400-
x = torch.cat([x, property_values], dim=1)
427+
enc_len = property_values.shape[1]
428+
# -------------- Node properties ---------------
429+
if isinstance(property, AllNodeTypeProperty):
430+
x[:, atom_offset : atom_offset + enc_len] = property_values
431+
atom_offset += enc_len
432+
fg_offset += enc_len
433+
graph_offset += enc_len
434+
435+
elif isinstance(property, AtomNodeTypeProperty):
436+
x[is_atom_node, atom_offset : atom_offset + enc_len] = property_values[
437+
is_atom_node
438+
]
439+
atom_offset += enc_len
440+
441+
elif isinstance(property, FGNodeTypeProperty):
442+
x[is_fg_node, fg_offset : fg_offset + enc_len] = property_values[
443+
is_fg_node
444+
]
445+
fg_offset += enc_len
446+
447+
elif isinstance(property, MoleculeProperty):
448+
x[is_graph_node, graph_offset : graph_offset + enc_len] = (
449+
property_values[is_graph_node]
450+
)
451+
graph_offset += enc_len
452+
453+
# ------------- Bond Properties --------------
401454
elif isinstance(property, BondProperty):
402455
# Concat/Duplicate properties values for undirected graph as `edge_index` has first src to tgt edges, then tgt to src edges
403456
edge_attr = torch.cat(
404457
[edge_attr, torch.cat([property_values, property_values], dim=0)],
405458
dim=1,
406459
)
407-
elif isinstance(property, MoleculeProperty):
408-
molecule_attr = torch.cat([molecule_attr, property_values], dim=1)
409460
else:
410461
raise TypeError(f"Unsupported property type: {type(property).__name__}")
411462

463+
total_used_columns = max(atom_offset, fg_offset, graph_offset)
464+
assert (
465+
total_used_columns <= max_len_node_properties
466+
), f"Used {total_used_columns} columns, but max allowed is {max_len_node_properties}"
467+
412468
return GeomData(
413469
x=x,
414470
edge_index=geom_data.edge_index,
415471
edge_attr=edge_attr,
416-
molecule_attr=molecule_attr,
417472
)
418473

419474

@@ -507,3 +562,7 @@ class ChEBI50_Atom_WGNOnly_GraphProp(AugGraphPropMixIn_WithGraphNode, ChEBIOver5
507562
"""ChEBIOver50 with atom-level nodes and graph node only."""
508563

509564
READER = AtomReader_WithGraphNodeOnly
565+
566+
567+
class ChEBI50_WFGE_WGN_AsPerNodeType(GraphPropAsPerNodeType, ChEBIOver50):
568+
READER = AtomFGReader_WithFGEdges_WithGraphNode

0 commit comments

Comments
 (0)