Skip to content

Commit 6d7e6bd

Browse files
committed
update readers for proteins
1 parent 9120538 commit 6d7e6bd

File tree

1 file changed

+6
-129
lines changed

1 file changed

+6
-129
lines changed

chebai/preprocessing/reader.py

Lines changed: 6 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -12,115 +12,8 @@
1212
load_model_and_alphabet_local,
1313
)
1414

15-
from chebai.preprocessing.collate import DefaultCollator, RaggedCollator
16-
17-
EMBEDDING_OFFSET = 10
18-
PADDING_TOKEN_INDEX = 0
19-
MASK_TOKEN_INDEX = 1
20-
CLS_TOKEN = 2
21-
22-
23-
class DataReader:
24-
"""
25-
Base class for reading and preprocessing data. Turns the raw input data (e.g., a SMILES string) into the model
26-
input format (e.g., a list of tokens).
27-
28-
Args:
29-
collator_kwargs: Optional dictionary of keyword arguments for the collator.
30-
token_path: Optional path for the token file.
31-
kwargs: Additional keyword arguments (not used).
32-
"""
33-
34-
COLLATOR = DefaultCollator
35-
36-
def __init__(
37-
self,
38-
collator_kwargs: Optional[Dict[str, Any]] = None,
39-
token_path: Optional[str] = None,
40-
**kwargs,
41-
):
42-
if collator_kwargs is None:
43-
collator_kwargs = dict()
44-
self.collator = self.COLLATOR(**collator_kwargs)
45-
self.dirname = os.path.dirname(__file__)
46-
self._token_path = token_path
47-
48-
def _get_raw_data(self, row: Dict[str, Any]) -> Any:
49-
"""Get raw data from the row."""
50-
return row["features"]
51-
52-
def _get_raw_label(self, row: Dict[str, Any]) -> Any:
53-
"""Get raw label from the row."""
54-
return row["labels"]
55-
56-
def _get_raw_id(self, row: Dict[str, Any]) -> Any:
57-
"""Get raw ID from the row."""
58-
return row.get("ident", row["features"])
59-
60-
def _get_raw_group(self, row: Dict[str, Any]) -> Any:
61-
"""Get raw group from the row."""
62-
return row.get("group", None)
63-
64-
def _get_additional_kwargs(self, row: Dict[str, Any]) -> Dict[str, Any]:
65-
"""Get additional keyword arguments from the row."""
66-
return row.get("additional_kwargs", dict())
67-
68-
def name(cls) -> str:
69-
"""Returns the name of the data reader."""
70-
raise NotImplementedError
71-
72-
@property
73-
def token_path(self) -> str:
74-
"""Get token path, create file if it does not exist yet."""
75-
if self._token_path is not None:
76-
return self._token_path
77-
token_path = os.path.join(self.dirname, "bin", self.name(), "tokens.txt")
78-
os.makedirs(os.path.join(self.dirname, "bin", self.name()), exist_ok=True)
79-
if not os.path.exists(token_path):
80-
with open(token_path, "x"):
81-
pass
82-
return token_path
83-
84-
def _read_id(self, raw_data: Any) -> Any:
85-
"""Read and return ID from raw data."""
86-
return raw_data
87-
88-
def _read_data(self, raw_data: Any) -> Any:
89-
"""Read and return data from raw data."""
90-
return raw_data
91-
92-
def _read_label(self, raw_label: Any) -> Any:
93-
"""Read and return label from raw label."""
94-
return raw_label
95-
96-
def _read_group(self, raw: Any) -> Any:
97-
"""Read and return group from raw group data."""
98-
return raw
99-
100-
def _read_components(self, row: Dict[str, Any]) -> Dict[str, Any]:
101-
"""Read and return components from the row."""
102-
return dict(
103-
features=self._get_raw_data(row),
104-
labels=self._get_raw_label(row),
105-
ident=self._get_raw_id(row),
106-
group=self._get_raw_group(row),
107-
additional_kwargs=self._get_additional_kwargs(row),
108-
)
109-
110-
def to_data(self, row: Dict[str, Any]) -> Dict[str, Any]:
111-
"""Convert raw row data to processed data."""
112-
d = self._read_components(row)
113-
return dict(
114-
features=self._read_data(d["features"]),
115-
labels=self._read_label(d["labels"]),
116-
ident=self._read_id(d["ident"]),
117-
group=self._read_group(d["group"]),
118-
**d["additional_kwargs"],
119-
)
120-
121-
def on_finish(self) -> None:
122-
"""Hook to run at the end of preprocessing."""
123-
return
15+
from chebai.preprocessing.collate import RaggedCollator
16+
from chebai.preprocessing.reader import DataReader
12417

12518

12619
class ProteinDataReader(DataReader):
@@ -139,31 +32,15 @@ class ProteinDataReader(DataReader):
13932

14033
COLLATOR = RaggedCollator
14134

35+
# fmt: off
14236
# One-letter notation for 21 amino acid symbols: the 20 standard amino acids plus "X"
14337
AA_LETTER = [
144-
"A",
145-
"R",
146-
"N",
147-
"D",
148-
"C",
149-
"Q",
150-
"E",
151-
"G",
152-
"H",
153-
"I",
154-
"L",
155-
"K",
156-
"M",
157-
"F",
158-
"P",
159-
"S",
160-
"T",
161-
"W",
162-
"Y",
163-
"V",
38+
"A", "R", "N", "D", "C", "Q", "E", "G", "H", "I",
39+
"L", "K", "M", "F", "P", "S", "T", "W", "Y", "V",
16440
# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L3-L5
16541
"X", # Considered valid in the latest paper (2024), reference number 3 in go_uniprot.py
16642
]
43+
# fmt: on
16744

16845
def name(self) -> str:
16946
"""

0 commit comments

Comments
 (0)