Commit 355085f

reformat with black
1 parent 57312a0 commit 355085f

File tree

3 files changed: +10 -6 lines changed

chebai_graph/preprocessing/datasets/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
     ChEBI50_WFGE_WGN_GraphProp,
     ChEBI50GraphData,
     ChEBI50GraphProperties,
-    ChEBI100GraphProperties
+    ChEBI100GraphProperties,
 )
 from .pubchem import PubChemGraphProperties

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 6 additions & 2 deletions
@@ -282,7 +282,9 @@ def _merge_props_into_base(self, row: pd.Series) -> GeomData:
             molecule_attr=molecule_attr,
         )
 
-    def load_processed_data(self, kind: Optional[str] = None, filename: Optional[str] = None) -> list[dict]:
+    def load_processed_data(
+        self, kind: Optional[str] = None, filename: Optional[str] = None
+    ) -> list[dict]:
         """
         Load dataset and merge cached properties into base features.
 
@@ -380,7 +382,9 @@ def __init__(self, properties=None, transform=None, **kwargs):
             f"Data module uses these properties (ordered): {', '.join([str(p) for p in self.properties])}",
         )
 
-    def load_processed_data(self, kind: Optional[str] = None, filename: Optional[str] = None) -> list[dict]:
+    def load_processed_data(
+        self, kind: Optional[str] = None, filename: Optional[str] = None
+    ) -> list[dict]:
         """
         Load dataset and merge cached properties into base features.
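For orientation, a hedged usage sketch of the load_processed_data signature reformatted above. Only the method name and its (kind, filename) parameters are taken from the diff; the ChEBI50GraphProperties class name comes from the __init__.py import list, while the no-argument constructor and the "train" value for kind are illustrative assumptions.

# Hedged usage sketch: only the load_processed_data signature is taken from the
# diff above; the plain constructor call and the "train" value are assumptions.
from chebai_graph.preprocessing.datasets import ChEBI50GraphProperties

data_module = ChEBI50GraphProperties()  # constructor arguments assumed / omitted
records: list[dict] = data_module.load_processed_data(kind="train")
print(f"loaded {len(records)} merged records")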

chebai_graph/preprocessing/property_encoder.py

Lines changed: 3 additions & 3 deletions
@@ -19,7 +19,7 @@ class PropertyEncoder(abc.ABC):
     def __init__(self, property, eval=False, **kwargs) -> None:
         self.property = property
         self._encoding_length: int = 1
-        self.eval = eval # if True, do not update cache (for index encoder)
+        self.eval = eval  # if True, do not update cache (for index encoder)
 
     @property
     def name(self) -> str:
@@ -150,11 +150,11 @@ def encode(self, token: str | None) -> torch.Tensor:
         if token is None:
             self._count_for_unk_token += 1
             return torch.tensor([self._unk_token_idx])
-
+
         if self.eval and str(token) not in self.cache:
             self._count_for_unk_token += 1
             return torch.tensor([self._unk_token_idx])
-
+
         if str(token) not in self.cache:
             self.cache[str(token)] = len(self.cache)
         return torch.tensor([self.cache[str(token)] + self.offset])
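To make the effect of the eval flag touched above concrete, here is a minimal, self-contained sketch that mirrors the caching logic of the encode() hunk. The class is a simplified stand-in, not the repository's PropertyEncoder/IndexEncoder; attribute names follow the diff, while the constructor parameters and their defaults are assumptions.

# Simplified stand-in mirroring the encode() logic shown in the diff above;
# constructor parameters and defaults are assumptions, not taken from the repo.
import torch


class IndexEncoderSketch:
    def __init__(self, eval=False, offset=0, unk_token_idx=0):
        self.eval = eval  # if True, do not update cache (for index encoder)
        self.offset = offset
        self._unk_token_idx = unk_token_idx
        self._count_for_unk_token = 0
        self.cache: dict[str, int] = {}

    def encode(self, token: str | None) -> torch.Tensor:
        # None tokens, and unseen tokens in eval mode, map to the unknown index
        if token is None:
            self._count_for_unk_token += 1
            return torch.tensor([self._unk_token_idx])

        if self.eval and str(token) not in self.cache:
            self._count_for_unk_token += 1
            return torch.tensor([self._unk_token_idx])

        # otherwise assign (or reuse) a cache index and shift it by the offset
        if str(token) not in self.cache:
            self.cache[str(token)] = len(self.cache)
        return torch.tensor([self.cache[str(token)] + self.offset])


encoder = IndexEncoderSketch()
print(encoder.encode("C"), encoder.encode("N"))  # fresh tokens get new indices
frozen = IndexEncoderSketch(eval=True)
print(frozen.encode("C"))  # unseen token falls back to the unk index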
