Commit 1ae6543

Merge branch 'dev' into out_dim_dynamic
2 parents: c73c62a + 052677e

10 files changed: 268 additions & 147 deletions

chebai/loss/bce_weighted.py

Lines changed: 12 additions & 15 deletions
@@ -50,32 +50,29 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
             and self.data_extractor is not None
             and all(
                 os.path.exists(
-                    os.path.join(self.data_extractor.processed_dir_main, file_name)
+                    os.path.join(self.data_extractor.processed_dir, file_name)
                 )
-                for file_name in self.data_extractor.processed_main_file_names
+                for file_name in self.data_extractor.processed_file_names
             )
             and self.pos_weight is None
         ):
             print(
                 f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})"
             )
-            complete_data = pd.concat(
+            complete_labels = torch.concat(
                 [
-                    pd.read_pickle(
-                        open(
-                            os.path.join(
-                                self.data_extractor.processed_dir_main,
-                                file_name,
-                            ),
-                            "rb",
-                        )
+                    torch.stack(
+                        [
+                            torch.Tensor(row["labels"])
+                            for row in self.data_extractor.load_processed_data(
+                                filename=file_name
+                            )
+                        ]
                     )
-                    for file_name in self.data_extractor.processed_main_file_names
+                    for file_name in self.data_extractor.processed_file_names
                 ]
             )
-            value_counts = []
-            for c in complete_data.columns[3:]:
-                value_counts.append(len([v for v in complete_data[c] if v]))
+            value_counts = complete_labels.sum(dim=0)
             weights = [
                 (1 - self.beta) / (1 - pow(self.beta, value)) for value in value_counts
             ]
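
The weighting follows the "effective number of samples" class-balanced scheme (cf. Cui et al., 2019): w = (1 − β) / (1 − β^n) for a class with n positives. A minimal, self-contained sketch of the tensor-based counting that replaces the per-column Python loop (the `labels` tensor below is invented; the real code stacks rows returned by `load_processed_data`):

    import torch

    # Invented stand-in for the stacked label rows from the processed files
    labels = torch.tensor([[1.0, 0.0, 1.0],
                           [1.0, 1.0, 0.0],
                           [1.0, 0.0, 0.0]])

    beta = 0.99
    value_counts = labels.sum(dim=0)  # positives per class: tensor([3., 1., 1.])
    weights = (1 - beta) / (1 - torch.pow(beta, value_counts))
    loss = torch.nn.BCEWithLogitsLoss(pos_weight=weights)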

chebai/models/ffn.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ def _get_prediction_and_labels(self, data, labels, model_output):
         loss_kwargs = data.get("loss_kwargs", dict())
         if "non_null_labels" in loss_kwargs:
             n = loss_kwargs["non_null_labels"]
-            d = data[n]
+            d = d[n]
         return torch.sigmoid(d), labels.int() if labels is not None else None

     def _process_for_loss(
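
The one-character fix is the whole change: `data` is the batch mapping, while `d` presumably already holds the model output, so the non-null filter must index `d` itself. A toy illustration of the intended behavior (names and shapes are hypothetical):

    import torch

    d = torch.randn(5, 3)        # hypothetical model outputs for a batch of 5
    non_null_labels = [0, 2, 4]  # indices of samples that actually have labels

    d = d[non_null_labels]   # keep only labeled samples -> shape (3, 3)
    preds = torch.sigmoid(d)  # what _get_prediction_and_labels returns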

chebai/preprocessing/datasets/scope/scope.py

Lines changed: 99 additions & 20 deletions
@@ -72,10 +72,12 @@ def __init__(
         self,
         scope_version: str,
         scope_version_train: Optional[str] = None,
+        max_sequence_len: int = 1000,
         **kwargs,
     ):
         self.scope_version: str = scope_version
         self.scope_version_train: str = scope_version_train
+        self.max_sequence_len: int = max_sequence_len

         super(_SCOPeDataExtractor, self).__init__(**kwargs)

@@ -195,21 +197,93 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
         """
         print("Extracting class hierarchy...")
         df_scope = self._get_scope_data()
+        pdb_chain_df = self._parse_pdb_sequence_file()
+        pdb_id_set = set(pdb_chain_df["pdb_id"])  # set membership checks are O(1)
+
+        # Initialize sets and dictionaries for storing edges and attributes
+        parent_node_edges, node_child_edges = set(), set()
+        node_attrs = {}
+        px_level_nodes = set()
+        sequence_nodes = dict()
+        px_to_seq_edges = set()
+        required_graph_nodes = set()
+
+        # Create a lookup dictionary for PDB chain sequences
+        lookup_dict = (
+            pdb_chain_df.groupby("pdb_id")[["chain_id", "sequence"]]
+            .apply(lambda x: dict(zip(x["chain_id"], x["sequence"])))
+            .to_dict()
+        )

-        g = nx.DiGraph()
+        def add_sequence_nodes_edges(chain_sequence, px_sun_id):
+            """Adds sequence nodes and edges connecting px-level nodes to sequence nodes."""
+            if chain_sequence not in sequence_nodes:
+                sequence_nodes[chain_sequence] = f"seq_{len(sequence_nodes)}"
+            px_to_seq_edges.add((px_sun_id, sequence_nodes[chain_sequence]))
+
+        # Step 1: Build the graph structure and store node attributes
+        for row in df_scope.itertuples(index=False):
+            if row.level == "px":
+                pdb_id, chain_id = row.sid[1:5], row.sid[5]

-        egdes = []
-        for _, row in df_scope.iterrows():
-            g.add_node(row["sunid"], **{"sid": row["sid"], "level": row["level"]})
-            if row["parent_sunid"] != -1:
-                egdes.append((row["parent_sunid"], row["sunid"]))
+                if pdb_id not in pdb_id_set or chain_id == "_":
+                    # Skip domain-level nodes whose pdb_id is missing from the
+                    # pdb_sequences.txt file, and chain_id "_", which denotes no chain
+                    continue
+                px_level_nodes.add(row.sunid)

-            for children_id in row["children_sunids"]:
-                egdes.append((row["sunid"], children_id))
+                # Add edges between px-level nodes and sequence nodes
+                if chain_id != ".":
+                    if chain_id not in lookup_dict[pdb_id]:
+                        continue
+                    add_sequence_nodes_edges(lookup_dict[pdb_id][chain_id], row.sunid)
+                else:
+                    # If chain_id is '.', connect all chains of this PDB ID
+                    for chain, chain_sequence in lookup_dict[pdb_id].items():
+                        add_sequence_nodes_edges(chain_sequence, row.sunid)
+            else:
+                required_graph_nodes.add(row.sunid)

-        g.add_edges_from(egdes)
+            node_attrs[row.sunid] = {"sid": row.sid, "level": row.level}

-        print("Computing transitive closure")
+            if row.parent_sunid != -1:
+                parent_node_edges.add((row.parent_sunid, row.sunid))
+
+            for child_id in row.children_sunids:
+                node_child_edges.add((row.sunid, child_id))
+
+        del df_scope, pdb_chain_df, pdb_id_set
+
+        g = nx.DiGraph()
+        g.add_nodes_from(node_attrs.items())
+        # Note: `add_edges_from` internally creates a node if it doesn't exist already
+        g.add_edges_from({(p, c) for p, c in parent_node_edges if p in node_attrs})
+        g.add_edges_from({(p, c) for p, c in node_child_edges if c in node_attrs})
+
+        seq_nodes = set(sequence_nodes.values())
+        g.add_nodes_from([(seq_id, {"level": "sequence"}) for seq_id in seq_nodes])
+        g.add_edges_from(
+            {
+                (px_node, seq_node)
+                for px_node, seq_node in px_to_seq_edges
+                if px_node in node_attrs and seq_node in seq_nodes
+            }
+        )
+
+        # Step 2: Count sequence successors for required graph nodes only
+        for node in required_graph_nodes:
+            num_seq_successors = sum(
+                g.nodes[child]["level"] == "sequence"
+                for child in nx.descendants(g, node)
+            )
+            g.nodes[node]["num_seq_successors"] = num_seq_successors
+
+        # Step 3: Remove nodes that are not required before computing the
+        # transitive closure, for better efficiency
+        g.remove_nodes_from(px_level_nodes | seq_nodes)
+
+        print("Computing transitive closure...")
+        # The transitive closure is not needed by `select_classes`, but is required in _SCOPeOverXPartial
         return nx.transitive_closure_dag(g)

     def _get_scope_data(self) -> pd.DataFrame:

@@ -388,7 +462,8 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:

         encoded_target_columns = []
         for level in hierarchy_levels:
-            encoded_target_columns.extend(lvl_to_target_cols_mapping[level])
+            if level in lvl_to_target_cols_mapping:
+                encoded_target_columns.extend(lvl_to_target_cols_mapping[level])

         print(
             f"{len(encoded_target_columns)} labels has been selected for specified threshold, "

@@ -471,12 +546,12 @@ def _parse_pdb_sequence_file(self) -> pd.DataFrame:
         for record in SeqIO.parse(
             os.path.join(self.scope_root_dir, self.raw_file_names_dict["PDB"]), "fasta"
         ):
+            if not record.seq or len(record.seq) > self.max_sequence_len:
+                continue
+
             pdb_id, chain = record.id.split("_")
-            sequence = (
-                re.sub(f"[^{valid_amino_acids}]", "X", str(record.seq))
-                if record.seq
-                else ""
-            )
+            sequence = re.sub(f"[^{valid_amino_acids}]", "X", str(record.seq))

             # Store as a dictionary entry (list of dicts -> DataFrame later)
             records.append(

@@ -777,12 +852,15 @@ def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict[str, List[int]]
         """
         selected_sunids_for_level = {}
         for node, attr_dict in g.nodes(data=True):
-            if g.out_degree(node) >= self.THRESHOLD:
+            if attr_dict["level"] in {"root", "px", "sequence"}:
+                # Skip nodes at the "root", "px", or "sequence" level
+                continue
+
+            # Check if the number of "sequence"-level successors meets the threshold
+            if g.nodes[node]["num_seq_successors"] >= self.THRESHOLD:
                 selected_sunids_for_level.setdefault(attr_dict["level"], []).append(
                     node
                 )
-        # Remove root node, as it will True for all instances
-        selected_sunids_for_level.pop("root", None)
         return selected_sunids_for_level

@@ -876,7 +954,8 @@ class SCOPeOverPartial2000(_SCOPeOverXPartial):


 if __name__ == "__main__":
-    scope = SCOPeOver2000(scope_version="2.08")
+    scope = SCOPeOver50(scope_version="2.08")
+
     # g = scope._extract_class_hierarchy("dummy/path")
     # # Save graph
     # import pickle
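
The central change to the hierarchy extraction: rather than thresholding on a node's direct out-degree, each retained node is annotated with `num_seq_successors`, the number of sequence-level nodes reachable from it, which `select_classes` then compares against the threshold. A toy sketch of that counting step (the miniature graph is invented for illustration):

    import networkx as nx

    g = nx.DiGraph()
    # Invented mini-hierarchy: class -> fold -> px -> sequence
    g.add_nodes_from([
        (1, {"level": "cl"}), (2, {"level": "cf"}),
        (3, {"level": "px"}), (4, {"level": "px"}),
        ("seq_0", {"level": "sequence"}), ("seq_1", {"level": "sequence"}),
    ])
    g.add_edges_from([(1, 2), (2, 3), (2, 4), (3, "seq_0"), (4, "seq_1")])

    # Count sequence-level descendants of each non-px, non-sequence node
    for node in (1, 2):
        g.nodes[node]["num_seq_successors"] = sum(
            g.nodes[d]["level"] == "sequence" for d in nx.descendants(g, node)
        )

    print(g.nodes[1]["num_seq_successors"])  # 2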

chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py

Lines changed: 4 additions & 8 deletions
@@ -213,10 +213,8 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame:
             "proteins",
             "accessions",
             "sequences",
-            # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L45-L58
-            "exp_annotations",  # Directly associated GO ids
             # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L60-L69
-            "prop_annotations",  # Transitively associated GO ids
+            "prop_annotations",  # Directly and transitively associated GO ids
             "esm2",
         ]

@@ -228,10 +226,8 @@ def _extract_required_data_from_splits(self) -> pd.DataFrame:
             ],
             ignore_index=True,
         )
-        new_df["go_ids"] = new_df.apply(
-            lambda row: self.extract_go_id(row["exp_annotations"])
-            + self.extract_go_id(row["prop_annotations"]),
-            axis=1,
+        new_df["go_ids"] = new_df["prop_annotations"].apply(
+            lambda x: self.extract_go_id(x)
         )

         data_df = pd.DataFrame(

@@ -270,7 +266,7 @@ def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame:
         """
         print("Generating labels based on terms.pkl file.......")
         parsed_go_ids: pd.Series = self._terms_df["gos"].apply(
-            lambda gos: DeepGO2MigratedData._parse_go_id(gos)
+            DeepGO2MigratedData._parse_go_id
         )
         all_go_ids_list = parsed_go_ids.values.tolist()
         self._classes = all_go_ids_list
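
Per the updated column comment, `prop_annotations` already contains the directly associated GO ids along with the transitively propagated ones, so also concatenating `exp_annotations` only duplicated entries. The `_parse_go_id` change is the usual simplification of passing a function to `Series.apply` directly instead of wrapping it in a lambda; both forms are equivalent, as this toy column shows (`extract_go_id` here is a hypothetical stand-in):

    import pandas as pd

    def extract_go_id(annotations):
        # Hypothetical parser: keep the numeric part of each "GO:XXXXXXX" id
        return [int(a.split(":")[1]) for a in annotations]

    df = pd.DataFrame({"prop_annotations": [["GO:0008150", "GO:0003674"]]})

    df["go_ids"] = df["prop_annotations"].apply(lambda x: extract_go_id(x))
    df["go_ids"] = df["prop_annotations"].apply(extract_go_id)  # equivalent
    print(df["go_ids"].iloc[0])  # [8150, 3674]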

chebai/result/classification.py

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ def print_metrics(
     print(f"Micro-Recall: {recall_micro(preds, labels):3f}")
     if markdown_output:
         print(
-            f"| Model | Macro-F1 | Micro-F1 | Macro-Precision | Micro-Precision | Macro-Recall | Micro-Recall | Balanced Accuracy"
+            f"| Model | Macro-F1 | Micro-F1 | Macro-Precision | Micro-Precision | Macro-Recall | Micro-Recall | Balanced Accuracy |"
         )
         print(f"| --- | --- | --- | --- | --- | --- | --- | --- |")
         print(

chebai/result/utils.py

Lines changed: 6 additions & 5 deletions
@@ -156,11 +156,12 @@ def evaluate_model(
             return test_preds, test_labels
         return test_preds, None
     elif len(preds_list) < 0:
-        torch.save(
-            _concat_tuple(preds_list),
-            os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
-        )
-        if labels_list[0] is not None:
+        if len(preds_list) > 0 and preds_list[0] is not None:
+            torch.save(
+                _concat_tuple(preds_list),
+                os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
+            )
+        if len(labels_list) > 0 and labels_list[0] is not None:
             torch.save(
                 _concat_tuple(labels_list),
                 os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"),
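
The new branch writes a buffer only when it actually holds tensors. A minimal sketch of the guarded-save pattern (with `torch.cat` standing in for the repo's `_concat_tuple`, and an invented output path):

    import os
    import torch

    def save_if_nonempty(tensors, path):
        # Save the concatenated buffer only when there is something to save
        if len(tensors) > 0 and tensors[0] is not None:
            torch.save(torch.cat(tensors), path)

    buffer_dir = "buffer"  # hypothetical directory
    os.makedirs(buffer_dir, exist_ok=True)
    save_if_nonempty([torch.ones(2, 3)], os.path.join(buffer_dir, "preds000.pt"))
    save_if_nonempty([], os.path.join(buffer_dir, "labels000.pt"))  # no-op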

configs/data/scope/scope50.yml

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 class_path: chebai.preprocessing.datasets.scope.scope.SCOPeOver50
 init_args:
-  scope_version: 2.08
+  scope_version: "2.08"

configs/model/electra.yml

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ init_args:
   optimizer_kwargs:
     lr: 1e-3
   config:
-    vocab_size: 8500
+    vocab_size: 1400
     max_position_embeddings: 1800
     num_attention_heads: 8
     num_hidden_layers: 6
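
These keys mirror the fields of Hugging Face's `ElectraConfig`; assuming the YAML `config` block is forwarded to `transformers` (not shown in this diff), the trimmed setting would instantiate roughly:

    from transformers import ElectraConfig

    # Assumed mapping of the YAML `config` block onto ElectraConfig kwargs
    config = ElectraConfig(
        vocab_size=1400,  # down from 8500, presumably sized to the actual tokenizer vocabulary
        max_position_embeddings=1800,
        num_attention_heads=8,
        num_hidden_layers=6,
    )
    print(config.hidden_size)  # remaining fields keep ElectraConfig defaults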

tutorials/data_exploration_go.ipynb

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@
    }
   },
   "outputs": [],
-  "source": "from chebai.preprocessing.datasets.deepGO.go_uniprot import GOUniProtOver250"
+  "source": "from chebai.preprocessing.datasets.go_uniprot import GOUniProtOver250"
  },
  {
   "cell_type": "code",
