Skip to content

Commit 57312a0

Browse files
committed
fix data processing
1 parent 9c6d915 commit 57312a0

File tree

4 files changed

+25
-7
lines changed

4 files changed

+25
-7
lines changed

chebai_graph/models/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def forward(self, batch: dict) -> torch.Tensor:
184184
torch.Tensor: Predicted output.
185185
"""
186186
graph_data = batch["features"][0]
187+
graph_data.to(self.device)
187188
assert isinstance(graph_data, GraphData)
188189
a = self.gnn(batch)
189190
a = scatter_add(a, graph_data.batch, dim=0)

chebai_graph/preprocessing/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
ChEBI50_WFGE_WGN_GraphProp,
1313
ChEBI50GraphData,
1414
ChEBI50GraphProperties,
15+
ChEBI100GraphProperties
1516
)
1617
from .pubchem import PubChemGraphProperties
1718

1819
__all__ = [
1920
"ChEBI50GraphFGAugmentorReader",
2021
"ChEBI50GraphProperties",
22+
"ChEBI100GraphProperties",
2123
"ChEBI50GraphData",
2224
"PubChemGraphProperties",
2325
"ChEBI50_Atom_WGNOnly_GraphProp",

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from abc import ABC
33
from collections.abc import Callable
44
from pprint import pformat
5+
from typing import Optional
56

67
import pandas as pd
78
import torch
@@ -281,7 +282,7 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData:
281282
molecule_attr=molecule_attr,
282283
)
283284

284-
def load_processed_data_from_file(self, filename: str) -> list[dict]:
285+
def load_processed_data(self, kind: Optional[str] = None, filename: Optional[str] = None) -> list[dict]:
285286
"""
286287
Load dataset and merge cached properties into base features.
287288
@@ -291,7 +292,7 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
291292
Returns:
292293
List of data entries, each a dictionary.
293294
"""
294-
base_data = super().load_processed_data_from_file(filename)
295+
base_data = super().load_processed_data(kind, filename)
295296
base_df = pd.DataFrame(base_data)
296297

297298
for property in self.properties:
@@ -379,7 +380,7 @@ def __init__(self, properties=None, transform=None, **kwargs):
379380
f"Data module uses these properties (ordered): {', '.join([str(p) for p in self.properties])}",
380381
)
381382

382-
def load_processed_data_from_file(self, filename: str) -> list[dict]:
383+
def load_processed_data(self, kind: Optional[str] = None, filename: Optional[str] = None) -> list[dict]:
383384
"""
384385
Load dataset and merge cached properties into base features.
385386
@@ -389,9 +390,8 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
389390
Returns:
390391
List of data entries, each a dictionary.
391392
"""
392-
base_data = super().load_processed_data_from_file(filename)
393+
base_data = super().load_processed_data(kind, filename)
393394
base_df = pd.DataFrame(base_data)
394-
395395
props_categories = {
396396
"AllNodeTypeProperties": [],
397397
"FGNodeTypeProperties": [],
@@ -442,6 +442,7 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
442442
)
443443

444444
for property in self.properties:
445+
rank_zero_info(f"Loading property {property.name}...")
445446
property_data = torch.load(
446447
self.get_property_path(property), weights_only=False
447448
)

chebai_graph/preprocessing/property_encoder.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ class PropertyEncoder(abc.ABC):
1616
**kwargs: Additional keyword arguments.
1717
"""
1818

19-
def __init__(self, property, **kwargs) -> None:
19+
def __init__(self, property, eval=False, **kwargs) -> None:
2020
self.property = property
2121
self._encoding_length: int = 1
22+
self.eval = eval # if True, do not update cache (for index encoder)
2223

2324
@property
2425
def name(self) -> str:
@@ -149,7 +150,11 @@ def encode(self, token: str | None) -> torch.Tensor:
149150
if token is None:
150151
self._count_for_unk_token += 1
151152
return torch.tensor([self._unk_token_idx])
152-
153+
154+
if self.eval and str(token) not in self.cache:
155+
self._count_for_unk_token += 1
156+
return torch.tensor([self._unk_token_idx])
157+
153158
if str(token) not in self.cache:
154159
self.cache[str(token)] = len(self.cache)
155160
return torch.tensor([self.cache[str(token)] + self.offset])
@@ -213,6 +218,15 @@ def encode(self, token: str | None) -> torch.Tensor:
213218
Returns:
214219
One-hot encoded tensor of shape (1, encoding_length).
215220
"""
221+
if self.eval:
222+
if token is None or str(token) not in self.cache:
223+
self._count_for_unk_token += 1
224+
return torch.zeros(self.get_encoding_length(), dtype=torch.int64)
225+
index = self.cache[str(token)] + self.offset
226+
return torch.nn.functional.one_hot(
227+
torch.tensor(index), num_classes=self.get_encoding_length()
228+
)
229+
216230
if token not in self.tokens_dict:
217231
self._count_for_unk_token += 1
218232
return torch.zeros(1, self.get_encoding_length(), dtype=torch.int64)

0 commit comments

Comments
 (0)