Skip to content

Commit 0ef9e87

Browse files
authored
Add eval mode to property encoder
Add eval mode to property encoder
2 parents 9279bd7 + 355085f commit 0ef9e87

File tree

4 files changed: 28 additions (+28) and 6 deletions (-6).

chebai_graph/models/base.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -184,6 +184,7 @@ def forward(self, batch: dict) -> torch.Tensor:
184184
torch.Tensor: Predicted output.
185185
"""
186186
graph_data = batch["features"][0]
187+
graph_data.to(self.device)
187188
assert isinstance(graph_data, GraphData)
188189
a = self.gnn(batch)
189190
a = scatter_add(a, graph_data.batch, dim=0)

chebai_graph/preprocessing/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,12 +12,14 @@
1212
ChEBI50_WFGE_WGN_GraphProp,
1313
ChEBI50GraphData,
1414
ChEBI50GraphProperties,
15+
ChEBI100GraphProperties,
1516
)
1617
from .pubchem import PubChemGraphProperties
1718

1819
__all__ = [
1920
"ChEBI50GraphFGAugmentorReader",
2021
"ChEBI50GraphProperties",
22+
"ChEBI100GraphProperties",
2123
"ChEBI50GraphData",
2224
"PubChemGraphProperties",
2325
"ChEBI50_Atom_WGNOnly_GraphProp",

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
from abc import ABC
33
from collections.abc import Callable
44
from pprint import pformat
5+
from typing import Optional
56

67
import pandas as pd
78
import torch
@@ -281,7 +282,9 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData:
281282
molecule_attr=molecule_attr,
282283
)
283284

284-
def load_processed_data_from_file(self, filename: str) -> list[dict]:
285+
def load_processed_data(
286+
self, kind: Optional[str] = None, filename: Optional[str] = None
287+
) -> list[dict]:
285288
"""
286289
Load dataset and merge cached properties into base features.
287290
@@ -291,7 +294,7 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
291294
Returns:
292295
List of data entries, each a dictionary.
293296
"""
294-
base_data = super().load_processed_data_from_file(filename)
297+
base_data = super().load_processed_data(kind, filename)
295298
base_df = pd.DataFrame(base_data)
296299

297300
for property in self.properties:
@@ -379,7 +382,9 @@ def __init__(self, properties=None, transform=None, **kwargs):
379382
f"Data module uses these properties (ordered): {', '.join([str(p) for p in self.properties])}",
380383
)
381384

382-
def load_processed_data_from_file(self, filename: str) -> list[dict]:
385+
def load_processed_data(
386+
self, kind: Optional[str] = None, filename: Optional[str] = None
387+
) -> list[dict]:
383388
"""
384389
Load dataset and merge cached properties into base features.
385390
@@ -389,9 +394,8 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
389394
Returns:
390395
List of data entries, each a dictionary.
391396
"""
392-
base_data = super().load_processed_data_from_file(filename)
397+
base_data = super().load_processed_data(kind, filename)
393398
base_df = pd.DataFrame(base_data)
394-
395399
props_categories = {
396400
"AllNodeTypeProperties": [],
397401
"FGNodeTypeProperties": [],
@@ -442,6 +446,7 @@ def load_processed_data_from_file(self, filename: str) -> list[dict]:
442446
)
443447

444448
for property in self.properties:
449+
rank_zero_info(f"Loading property {property.name}...")
445450
property_data = torch.load(
446451
self.get_property_path(property), weights_only=False
447452
)

chebai_graph/preprocessing/property_encoder.py

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,9 +16,10 @@ class PropertyEncoder(abc.ABC):
1616
**kwargs: Additional keyword arguments.
1717
"""
1818

19-
def __init__(self, property, **kwargs) -> None:
19+
def __init__(self, property, eval=False, **kwargs) -> None:
2020
self.property = property
2121
self._encoding_length: int = 1
22+
self.eval = eval # if True, do not update cache (for index encoder)
2223

2324
@property
2425
def name(self) -> str:
@@ -150,6 +151,10 @@ def encode(self, token: str | None) -> torch.Tensor:
150151
self._count_for_unk_token += 1
151152
return torch.tensor([self._unk_token_idx])
152153

154+
if self.eval and str(token) not in self.cache:
155+
self._count_for_unk_token += 1
156+
return torch.tensor([self._unk_token_idx])
157+
153158
if str(token) not in self.cache:
154159
self.cache[str(token)] = len(self.cache)
155160
return torch.tensor([self.cache[str(token)] + self.offset])
@@ -213,6 +218,15 @@ def encode(self, token: str | None) -> torch.Tensor:
213218
Returns:
214219
One-hot encoded tensor of shape (1, encoding_length).
215220
"""
221+
if self.eval:
222+
if token is None or str(token) not in self.cache:
223+
self._count_for_unk_token += 1
224+
return torch.zeros(self.get_encoding_length(), dtype=torch.int64)
225+
index = self.cache[str(token)] + self.offset
226+
return torch.nn.functional.one_hot(
227+
torch.tensor(index), num_classes=self.get_encoding_length()
228+
)
229+
216230
if token not in self.tokens_dict:
217231
self._count_for_unk_token += 1
218232
return torch.zeros(1, self.get_encoding_length(), dtype=torch.int64)

0 commit comments

Comments
 (0)