1212 load_model_and_alphabet_local ,
1313)
1414
15- from chebai .preprocessing .collate import DefaultCollator , RaggedCollator
16-
# Module-level integer constants.
# NOTE(review): none of these is referenced in the visible part of this file.
# Judging by the names only, they look like reserved token indices (padding,
# mask, CLS) plus an offset for the first regular token index -- confirm
# against the code that consumes them before relying on this.
17- EMBEDDING_OFFSET = 10
18- PADDING_TOKEN_INDEX = 0
19- MASK_TOKEN_INDEX = 1
20- CLS_TOKEN = 2
21-
22-
class DataReader:
    """
    Base class for reading and preprocessing data. Turns the raw input data (e.g., a SMILES string) into the model
    input format (e.g., a list of tokens).

    Processing of one row happens in two stages (see ``to_data``):

    1. the ``_get_raw_*`` accessors pick the individual values out of the row;
    2. the ``_read_*`` converters transform each value.

    All converters in this base class return their input unchanged; subclasses
    override whichever accessor/converter they need.

    Args:
        collator_kwargs: Optional dictionary of keyword arguments for the collator.
        token_path: Optional path for the token file.
        kwargs: Additional keyword arguments (not used).
    """

    # Collator class instantiated once per reader in ``__init__``.
    COLLATOR = DefaultCollator

    def __init__(
        self,
        collator_kwargs: Optional[Dict[str, Any]] = None,
        token_path: Optional[str] = None,
        **kwargs,
    ):
        if collator_kwargs is None:
            collator_kwargs = {}
        self.collator = self.COLLATOR(**collator_kwargs)
        # Directory containing this module; used as the root for the default token file.
        self.dirname = os.path.dirname(__file__)
        self._token_path = token_path

    def _get_raw_data(self, row: Dict[str, Any]) -> Any:
        """Get raw data from the row (the ``"features"`` entry)."""
        return row["features"]

    def _get_raw_label(self, row: Dict[str, Any]) -> Any:
        """Get raw label from the row (the ``"labels"`` entry)."""
        return row["labels"]

    def _get_raw_id(self, row: Dict[str, Any]) -> Any:
        """
        Get raw ID from the row.

        Returns the ``"ident"`` entry if present, otherwise falls back to the
        ``"features"`` entry.

        Raises:
            KeyError: If the row has neither ``"ident"`` nor ``"features"``.
        """
        # Look up the fallback lazily: ``row.get("ident", row["features"])`` would
        # evaluate ``row["features"]`` even when "ident" exists and raise KeyError
        # for rows that carry an identifier but no "features" entry.
        if "ident" in row:
            return row["ident"]
        return row["features"]

    def _get_raw_group(self, row: Dict[str, Any]) -> Any:
        """Get raw group from the row, or ``None`` if the row has no ``"group"`` entry."""
        return row.get("group", None)

    def _get_additional_kwargs(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Get additional keyword arguments from the row (empty dict if absent)."""
        return row.get("additional_kwargs", {})

    # NOTE(review): the parameter is called ``cls`` but no ``@classmethod``
    # decorator is visible here, and ``token_path`` calls it on an instance.
    def name(cls) -> str:
        """Returns the name of the data reader."""
        raise NotImplementedError

    @property
    def token_path(self) -> str:
        """
        Get token path, create file if it does not exist yet.

        If a path was passed to the constructor it is returned as-is (nothing is
        created). Otherwise the path is ``<module dir>/bin/<name()>/tokens.txt``;
        the directory and an empty file are created on demand.
        """
        if self._token_path is not None:
            return self._token_path
        token_dir = os.path.join(self.dirname, "bin", self.name())
        os.makedirs(token_dir, exist_ok=True)
        token_path = os.path.join(token_dir, "tokens.txt")
        # Create the file atomically instead of "check, then create": with an
        # exists() check, two processes can both see the file missing and the
        # slower one then crashes in open(..., "x").
        try:
            with open(token_path, "x"):
                pass
        except FileExistsError:
            pass
        return token_path

    def _read_id(self, raw_data: Any) -> Any:
        """Read and return ID from raw data (identity in the base class)."""
        return raw_data

    def _read_data(self, raw_data: Any) -> Any:
        """Read and return data from raw data (identity in the base class)."""
        return raw_data

    def _read_label(self, raw_label: Any) -> Any:
        """Read and return label from raw label (identity in the base class)."""
        return raw_label

    def _read_group(self, raw: Any) -> Any:
        """Read and return group from raw group data (identity in the base class)."""
        return raw

    def _read_components(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """
        Read and return components from the row.

        Returns:
            A dict with the keys ``features``, ``labels``, ``ident``, ``group``
            and ``additional_kwargs``, each filled by the matching
            ``_get_raw_*`` / ``_get_additional_kwargs`` accessor.
        """
        return dict(
            features=self._get_raw_data(row),
            labels=self._get_raw_label(row),
            ident=self._get_raw_id(row),
            group=self._get_raw_group(row),
            additional_kwargs=self._get_additional_kwargs(row),
        )

    def to_data(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """
        Convert raw row data to processed data.

        Returns:
            A dict with the converted ``features``, ``labels``, ``ident`` and
            ``group`` values, plus every entry of the row's additional kwargs
            merged in at the top level.

        Raises:
            TypeError: If the additional kwargs contain one of the four fixed
                keys (duplicate keyword argument to ``dict``).
        """
        d = self._read_components(row)
        # Keep the dict(...) call form: unlike a dict literal it rejects
        # additional kwargs that would shadow one of the fixed fields.
        return dict(
            features=self._read_data(d["features"]),
            labels=self._read_label(d["labels"]),
            ident=self._read_id(d["ident"]),
            group=self._read_group(d["group"]),
            **d["additional_kwargs"],
        )

    def on_finish(self) -> None:
        """Hook to run at the end of preprocessing (no-op in the base class)."""
        return
15+ from chebai .preprocessing .collate import RaggedCollator
16+ from chebai .preprocessing .reader import DataReader
12417
12518
12619class ProteinDataReader (DataReader ):
@@ -139,31 +32,15 @@ class ProteinDataReader(DataReader):
13932
14033 COLLATOR = RaggedCollator
14134
35+ # fmt: off
14236 # 21 natural amino acid notation
14337 AA_LETTER = [
144- "A" ,
145- "R" ,
146- "N" ,
147- "D" ,
148- "C" ,
149- "Q" ,
150- "E" ,
151- "G" ,
152- "H" ,
153- "I" ,
154- "L" ,
155- "K" ,
156- "M" ,
157- "F" ,
158- "P" ,
159- "S" ,
160- "T" ,
161- "W" ,
162- "Y" ,
163- "V" ,
38+ "A" , "R" , "N" , "D" , "C" , "Q" , "E" , "G" , "H" , "I" ,
39+ "L" , "K" , "M" , "F" , "P" , "S" , "T" , "W" , "Y" , "V" ,
16440 # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L3-L5
16541 "X" , # Consider valid in latest paper year 2024 Reference number 3 in go_uniprot.py
16642 ]
43+ # fmt: on
16744
16845 def name (self ) -> str :
16946 """
0 commit comments