
Commit 108d9ca

trigrams / n-grams combining several amino acids into one token
1 parent 6f463de

File tree

2 files changed: +48 −10 lines


chebai/preprocessing/datasets/go_uniprot.py

Lines changed: 10 additions & 0 deletions

@@ -563,6 +563,16 @@ def base_dir(self) -> str:
         """
         return os.path.join("data", f"GO_UniProt")
 
+    @property
+    def identifier(self) -> tuple:
+        """Identifier for the dataset."""
+        # Overriding identifier instead of reader.name to keep the same tokens.txt file, but a different processed_dir folder
+        if not isinstance(self.reader, dr.ProteinDataReader):
+            raise ValueError("Need ProteinDataReader for identifier")
+        if self.reader.n_gram is not None:
+            return (f"{self.reader.name()}_{self.reader.n_gram}_gram",)
+        return (self.reader.name(),)
+
     @property
     def raw_file_names_dict(self) -> dict:
         """

chebai/preprocessing/reader.py

Lines changed: 38 additions & 10 deletions
@@ -382,15 +382,24 @@ def name(cls) -> str:
         """
         return "protein_token"
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, n_gram: Optional[int] = None, **kwargs):
         """
         Initializes the ProteinDataReader, loading existing tokens from the specified token file.
 
         Args:
             *args: Additional positional arguments passed to the base class.
             **kwargs: Additional keyword arguments passed to the base class.
         """
+        if n_gram is not None:
+            assert (
+                int(n_gram) >= 2
+            ), "Ngrams must be greater than or equal to 2 if provided."
+            self.n_gram = int(n_gram)
+        else:
+            self.n_gram = None
+
         super().__init__(*args, **kwargs)
+
         # Load the existing tokens from the token file into a cache
         with open(self.token_path, "r") as pk:
             self.cache = [x.strip() for x in pk]
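The new n_gram argument is not yet described in the docstring; its contract, as a standalone sketch (the helper function below is hypothetical, mirroring the guard above):

from typing import Optional

def validate_n_gram(n_gram: Optional[int]) -> Optional[int]:
    """Hypothetical mirror of the constructor's n_gram guard."""
    if n_gram is not None:
        assert int(n_gram) >= 2, "Ngrams must be greater than or equal to 2 if provided."
        return int(n_gram)
    return None

print(validate_n_gram(3))     # 3
print(validate_n_gram(None))  # None (falls back to single-letter tokenization)
# validate_n_gram(1) raises AssertionError: a 1-gram is just a single letter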
@@ -405,14 +414,25 @@ def _get_token_index(self, token: str) -> int:
         Returns:
             int: The index of the token, offset by the predefined EMBEDDING_OFFSET.
         """
-        if str(token) not in self.AA_LETTER:
-            raise KeyError(
-                f"Invalid token '{token}' encountered. "
-                f"Please ensure that the input only contains valid amino acids "
-                f"20 Valid natural amino acid notation: {self.AA_LETTER}"
-                f"Refer to the amino acid sequence details here: "
-                f"https://en.wikipedia.org/wiki/Protein_primary_structure"
-            )
+        error_str = (
+            f"Please ensure that the input only contains valid amino acids. "
+            f"The 20 valid natural amino acid notations are: {self.AA_LETTER}. "
+            f"Refer to the amino acid sequence details here: "
+            f"https://en.wikipedia.org/wiki/Protein_primary_structure"
+        )
+
+        if self.n_gram is None:
+            # Single-letter amino acid token check
+            if str(token) not in self.AA_LETTER:
+                raise KeyError(f"Invalid token '{token}' encountered. " + error_str)
+        else:
+            # n-gram token validation: ensure that each component of the n-gram is valid
+            for aa in token:
+                if aa not in self.AA_LETTER:
+                    raise KeyError(
+                        f"Invalid token '{token}' encountered as part of a {self.n_gram}-gram. "
+                        + error_str
+                    )
 
         if str(token) not in self.cache:
             self.cache.append(str(token))
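A minimal sketch of the per-character n-gram validation above, assuming AA_LETTER holds the 20 standard one-letter codes (the exact contents of AA_LETTER are not shown in this diff):

AA_LETTER = set("ACDEFGHIKLMNPQRSTVWY")  # assumed: the 20 standard one-letter codes

def check_ngram(token: str) -> None:
    """Raise KeyError if any character of the n-gram is not a valid amino acid."""
    for aa in token:
        if aa not in AA_LETTER:
            raise KeyError(f"Invalid token '{token}' encountered.")

check_ngram("MKV")    # passes: M, K, and V are all valid
# check_ngram("MXB")  # would raise: 'X' is not a standard one-letter code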
@@ -428,7 +448,15 @@ def _read_data(self, raw_data: str) -> List[int]:
         Returns:
             List[int]: A list of integers representing the indices of the amino acid tokens.
         """
-        # In the case of protein sequences, each amino acid is typically represented by a single letter.
+        if self.n_gram is not None:
+            # Tokenize the sequence into overlapping n-grams
+            tokens = [
+                raw_data[i : i + self.n_gram]
+                for i in range(len(raw_data) - self.n_gram + 1)
+            ]
+            return [self._get_token_index(gram) for gram in tokens]
+
+        # If n_gram is None, tokenize the sequence at the amino acid level (single-letter representation)
         return [self._get_token_index(aa) for aa in raw_data]
 
     def on_finish(self) -> None:
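A minimal, self-contained sketch of the sliding-window tokenization that _read_data now performs for n-grams (the function name is an assumption for this sketch):

def to_n_grams(sequence: str, n_gram: int) -> list:
    """Split a sequence into overlapping n-grams, as in _read_data above."""
    return [sequence[i : i + n_gram] for i in range(len(sequence) - n_gram + 1)]

print(to_n_grams("MKVLA", 3))  # ['MKV', 'KVL', 'VLA']

A sequence of length L yields L − n + 1 overlapping tokens, so each residue (away from the ends) appears in n consecutive tokens.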
