Commit aa4545a

committed
default max len of GO and SCOPe should be the same
- needed for pretraining (max_position_embeddings)
1 parent f73161f commit aa4545a

File tree

3 files changed, +8 -11 lines changed


chebai_proteins/preprocessing/datasets/deepGO/go_uniprot.py

Lines changed: 5 additions & 5 deletions
@@ -102,17 +102,18 @@ class _GOUniProtDataExtractor(_DynamicDataset, ABC):
 
     # Gene Ontology (GO) has three major branches, one for biological processes (BP), molecular functions (MF) and
     # cellular components (CC). The value "all" will take data related to all three branches into account.
+    # TODO: should we be really allowing all branches for single dataset?
     _ALL_GO_BRANCHES: str = "all"
     _GO_BRANCH_NAMESPACE: Dict[str, str] = {
         "BP": "biological_process",
         "MF": "molecular_function",
         "CC": "cellular_component",
     }
 
-    def __init__(self, **kwargs):
-        self.go_branch: str = self._get_go_branch(**kwargs)
+    def __init__(self, go_branch: str, max_sequence_len: int = 1002, **kwargs):
+        self.go_branch: str = self._get_go_branch(go_branch)
 
-        self.max_sequence_length: int = int(kwargs.get("max_sequence_length", 1002))
+        self.max_sequence_length: int = int(max_sequence_len)
         assert (
             self.max_sequence_length >= 1
         ), "Max sequence length should be greater than or equal to 1."
@@ -126,7 +127,7 @@ def __init__(self, **kwargs):
         )
 
     @classmethod
-    def _get_go_branch(cls, **kwargs) -> str:
+    def _get_go_branch(cls, go_branch_value: str, **kwargs) -> str:
         """
         Retrieves the Gene Ontology (GO) branch based on provided keyword arguments.
         This method checks if a valid GO branch value is provided in the keyword arguments.
@@ -141,7 +142,6 @@ def _get_go_branch(cls, **kwargs) -> str:
            ValueError: If the provided 'go_branch' value is not in the allowed list of values.
        """
 
-        go_branch_value = kwargs.get("go_branch", cls._ALL_GO_BRANCHES)
        allowed_values = list(cls._GO_BRANCH_NAMESPACE.keys()) + [cls._ALL_GO_BRANCHES]
        if go_branch_value not in allowed_values:
            raise ValueError(
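
As a quick orientation (not code from the repository), the sketch below mirrors what the changed constructor now does: go_branch is an explicit, validated argument rather than a **kwargs lookup, and the maximum sequence length comes from max_sequence_len with a default of 1002. The class name GOExtractorSketch and the message wording are illustrative assumptions.

from typing import Dict

# Standalone sketch mirroring _GOUniProtDataExtractor after this commit;
# the class name and message texts are illustrative, not the repository's.
_ALL_GO_BRANCHES: str = "all"
_GO_BRANCH_NAMESPACE: Dict[str, str] = {
    "BP": "biological_process",
    "MF": "molecular_function",
    "CC": "cellular_component",
}


class GOExtractorSketch:
    def __init__(self, go_branch: str, max_sequence_len: int = 1002):
        # go_branch is validated up front instead of being pulled from **kwargs
        # with a silent fallback to "all".
        allowed = list(_GO_BRANCH_NAMESPACE.keys()) + [_ALL_GO_BRANCHES]
        if go_branch not in allowed:
            raise ValueError(f"go_branch must be one of {allowed}, got {go_branch!r}")
        self.go_branch = go_branch
        self.max_sequence_length = int(max_sequence_len)
        assert self.max_sequence_length >= 1


extractor = GOExtractorSketch("MF")
print(extractor.go_branch, extractor.max_sequence_length)  # MF 1002

try:
    GOExtractorSketch("XX")
except ValueError as err:
    print(err)  # invalid branch names are rejected immediately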

chebai_proteins/preprocessing/datasets/scope/scope.py

Lines changed: 2 additions & 5 deletions
@@ -72,12 +72,12 @@ def __init__(
         self,
         scope_version: str,
         scope_version_train: Optional[str] = None,
-        max_sequence_len: int = 1000,
+        max_sequence_len: int = 1002,
         **kwargs,
     ):
         self.scope_version: str = scope_version
         self.scope_version_train: str = scope_version_train
-        self.max_sequence_len: int = max_sequence_len
+        self.max_sequence_len: int = int(max_sequence_len)
 
         super(_SCOPeDataExtractor, self).__init__(**kwargs)
 
@@ -224,7 +224,6 @@ def add_sequence_nodes_edges(chain_sequence, px_sun_id):
         # Step 1: Build the graph structure and store node attributes
         for row in df_scope.itertuples(index=False):
             if row.level == "px":
-
                 pdb_id, chain_id = row.sid[1:5], row.sid[5]
 
                 if pdb_id not in pdb_id_set or chain_id == "_":
@@ -546,7 +545,6 @@ def _parse_pdb_sequence_file(self) -> pd.DataFrame:
         for record in SeqIO.parse(
             os.path.join(self.scope_root_dir, self.raw_file_names_dict["PDB"]), "fasta"
         ):
-
            if not record.seq or len(record.seq) > self.max_sequence_len:
                continue
 
@@ -934,7 +932,6 @@ class SCOPeOver2000(_SCOPeOverX):
 
 
 class SCOPeOver50(_SCOPeOverX):
-
    THRESHOLD = 50
 
 
configs/model/electra.yml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ init_args:
   lr: 1e-3
   config:
     vocab_size: 31
-    max_position_embeddings: 1000
+    max_position_embeddings: 1002
     num_attention_heads: 8
     num_hidden_layers: 6
     type_vocab_size: 1
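
To see why this value has to track the dataset defaults, here is a minimal sketch assuming the YAML keys map onto HuggingFace's ElectraConfig and ElectraModel (the wrapping model class is not shown in this diff): the position-embedding table needs at least as many rows as the longest tokenized sequence, so a 1002-token protein only fits once max_position_embeddings is 1002 as well.

import torch
from transformers import ElectraConfig, ElectraModel

MAX_SEQUENCE_LEN = 1002  # shared dataset default after this commit

config = ElectraConfig(
    vocab_size=31,
    max_position_embeddings=MAX_SEQUENCE_LEN,
    num_attention_heads=8,
    num_hidden_layers=6,
    type_vocab_size=1,
)
model = ElectraModel(config)

# Position ids default to 0..seq_len-1, so a full-length sequence now fits;
# with the previous value of 1000, a 1002-token input would not fit the
# position-embedding table during pretraining.
input_ids = torch.zeros((1, MAX_SEQUENCE_LEN), dtype=torch.long)
hidden = model(input_ids=input_ids).last_hidden_state
print(hidden.shape)  # torch.Size([1, 1002, 256]) with the default hidden size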
