Skip to content

Commit 66dd504

Browse files
committed
Update protein_pretraining.py
1 parent fc50c31 commit 66dd504

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

chebai/preprocessing/datasets/protein_pretraining.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,20 @@ def __init__(self, **kwargs):
4040
"""
4141
self._go_uniprot_extractor = GOUniProtOver250()
4242
assert self._go_uniprot_extractor.go_branch == GOUniProtOver250._ALL_GO_BRANCHES
43+
44+
self.max_sequence_length: int = int(kwargs.get("max_sequence_length", 1002))
45+
assert (
46+
self.max_sequence_length >= 1
47+
), "Max sequence length should be greater than or equal to 1."
48+
4349
super(_ProteinPretrainingData, self).__init__(**kwargs)
4450

51+
if self.reader.n_gram is not None:
52+
assert self.max_sequence_length >= self.reader.n_gram, (
53+
f"max_sequence_length ({self.max_sequence_length}) must be greater than "
54+
f"or equal to n_gram ({self.reader.n_gram})."
55+
)
56+
4557
# ------------------------------ Phase: Prepare data -----------------------------------
4658
def prepare_data(self, *args: Any, **kwargs: Any) -> None:
4759
"""
@@ -120,6 +132,10 @@ def _parse_protein_data_for_pretraining(self) -> pd.DataFrame:
120132
# Consider protein with only sequence representation
121133
continue
122134

135+
if len(record.sequence) > self.max_sequence_length:
136+
# Skip proteins whose sequence is longer than max_sequence_length
137+
continue
138+
123139
if any(aa in AMBIGUOUS_AMINO_ACIDS for aa in record.sequence):
124140
# Skip proteins with ambiguous amino acid codes
125141
continue
@@ -260,4 +276,4 @@ def _name(self) -> str:
260276
Returns:
261277
str: A string identifier of the form "Swiss_<max_sequence_length>", representing the name of this data module.
262278
"""
263-
return "SwissProteinPretrain"
279+
return f"Swiss_{self.max_sequence_length}"

0 commit comments

Comments
 (0)