Skip to content

Commit f4d1d74

Browse files
committed
scope: data preparation code
1 parent 3b17487 commit f4d1d74

File tree

1 file changed

+196
-99
lines changed
  • chebai/preprocessing/datasets/scope

1 file changed

+196
-99
lines changed

chebai/preprocessing/datasets/scope/scope.py

Lines changed: 196 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,28 @@ class _SCOPeDataExtractor(_DynamicDataset, ABC):
5151
"https://files.rcsb.org/pub/pdb/derived_data/pdb_seqres.txt.gz"
5252
)
5353

54+
SCOPE_HIERARCHY: Dict[str, str] = {
55+
"cl": "class",
56+
"cf": "fold",
57+
"sf": "superfamily",
58+
"fa": "family",
59+
"dm": "protein",
60+
"sp": "species",
61+
"px": "domain",
62+
}
63+
5464
def __init__(
5565
self,
5666
scope_version: float,
5767
scope_version_train: Optional[float] = None,
68+
scope_hierarchy_level: str = "cl",
5869
**kwargs,
5970
):
6071

72+
assert (
73+
scope_hierarchy_level in self.SCOPE_HIERARCHY.keys()
74+
), f"level can contain only one of the following values {self.SCOPE_HIERARCHY.keys()}"
75+
self.scope_hierarchy_level = scope_hierarchy_level
6176
self.scope_version: float = scope_version
6277
self.scope_version_train: float = scope_version_train
6378

@@ -67,7 +82,8 @@ def __init__(
6782
# Instantiate another same class with "scope_version" as "scope_version_train", if train_version is given
6883
# This is to get the data from respective directory related to "scope_version_train"
6984
_init_kwargs = kwargs
70-
_init_kwargs["chebi_version"] = self.scope_version_train
85+
_init_kwargs["scope_version"] = self.scope_version_train
86+
_init_kwargs["scope_hierarchy_level"] = self.scope_hierarchy_level
7187
self._scope_version_train_obj = self.__class__(
7288
**_init_kwargs,
7389
)
@@ -150,18 +166,40 @@ def _download_scope_raw_data(self) -> str:
150166
open(scope_path, "wb").write(r.content)
151167
return "dummy/path"
152168

153-
def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]:
154-
pdb_chain_seq_mapping: Dict[str, Dict[str, str]] = {}
155-
for record in SeqIO.parse(
156-
os.path.join(self.raw_dir, self.raw_file_names_dict["PDB"]), "fasta"
157-
):
158-
pdb_id, chain = record.id.split("_")
159-
pdb_chain_seq_mapping.setdefault(pdb_id, {})[chain] = str(record.seq)
160-
return pdb_chain_seq_mapping
161-
162169
def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
163170
print("Extracting class hierarchy...")
171+
df_scope = self._get_scope_data()
172+
173+
g = nx.DiGraph()
174+
175+
egdes = []
176+
for _, row in df_scope.iterrows():
177+
g.add_node(row["sunid"], **{"sid": row["sid"], "level": row["level"]})
178+
if row["parent_sunid"] != -1:
179+
egdes.append((row["parent_sunid"], row["sunid"]))
180+
181+
for children_id in row["children_sunids"]:
182+
egdes.append((row["sunid"], children_id))
183+
184+
g.add_edges_from(egdes)
185+
186+
print("Computing transitive closure")
187+
return nx.transitive_closure_dag(g)
188+
189+
def _get_scope_data(self) -> pd.DataFrame:
190+
df_cla = self._get_classification_data()
191+
df_hie = self._get_hierarchy_data()
192+
df_des = self._get_node_description_data()
193+
df_hie_with_cla = pd.merge(df_hie, df_cla, how="left", on="sunid")
194+
df_all = pd.merge(
195+
df_hie_with_cla,
196+
df_des.drop(columns=["sid"], axis=1),
197+
how="left",
198+
on="sunid",
199+
)
200+
return df_all
164201

202+
def _get_classification_data(self) -> pd.DataFrame:
165203
# Load and preprocess CLA file
166204
df_cla = pd.read_csv(
167205
os.path.join(self.raw_dir, self.raw_file_names_dict["CLA"]),
@@ -175,125 +213,166 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
175213
"description",
176214
"sccs",
177215
"sunid",
178-
"ancestor_nodes",
216+
"hie_levels",
179217
]
180-
df_cla["sunid"] = pd.to_numeric(
181-
df_cla["sunid"], errors="coerce", downcast="integer"
182-
)
183-
df_cla["ancestor_nodes"] = df_cla["ancestor_nodes"].apply(
218+
219+
# Convert to dict - {cl:46456, cf:46457, sf:46458, fa:46459, dm:46460, sp:116748, px:113449}
220+
df_cla["hie_levels"] = df_cla["hie_levels"].apply(
184221
lambda x: {k: int(v) for k, v in (item.split("=") for item in x.split(","))}
185222
)
186-
df_cla.set_index("sunid", inplace=True)
187223

224+
# Split ancestor_nodes into separate columns and assign values
225+
for key in self.SCOPE_HIERARCHY.keys():
226+
df_cla[self.SCOPE_HIERARCHY[key]] = df_cla["hie_levels"].apply(
227+
lambda x: x[key]
228+
)
229+
230+
df_cla["sunid"] = df_cla["sunid"].astype("int64")
231+
232+
return df_cla
233+
234+
def _get_hierarchy_data(self) -> pd.DataFrame:
188235
# Load and preprocess HIE file
189236
df_hie = pd.read_csv(
190237
os.path.join(self.raw_dir, self.raw_file_names_dict["HIE"]),
191238
sep="\t",
192239
header=None,
193240
comment="#",
241+
low_memory=False,
194242
)
195243
df_hie.columns = ["sunid", "parent_sunid", "children_sunids"]
196-
df_hie["sunid"] = pd.to_numeric(
197-
df_hie["sunid"], errors="coerce", downcast="integer"
198-
)
244+
245+
# if not parent id, then insert -1
199246
df_hie["parent_sunid"] = df_hie["parent_sunid"].replace("-", -1).astype(int)
247+
# convert children ids to list of ids
200248
df_hie["children_sunids"] = df_hie["children_sunids"].apply(
201249
lambda x: list(map(int, x.split(","))) if x != "-" else []
202250
)
203251

204-
# Initialize directed graph
205-
g = nx.DiGraph()
252+
# Ensure the 'sunid' column in both DataFrames has the same type
253+
df_hie["sunid"] = df_hie["sunid"].astype("int64")
254+
return df_hie
206255

207-
# Add nodes and edges efficiently
208-
g.add_edges_from(
209-
df_hie[df_hie["parent_sunid"] != -1].apply(
210-
lambda row: (row["parent_sunid"], row["sunid"]), axis=1
211-
)
212-
)
213-
g.add_edges_from(
214-
df_hie.explode("children_sunids")
215-
.dropna()
216-
.apply(lambda row: (row["sunid"], row["children_sunids"]), axis=1)
256+
def _get_node_description_data(self):
257+
        # Load and preprocess DES file
258+
df_des = pd.read_csv(
259+
os.path.join(self.raw_dir, self.raw_file_names_dict["DES"]),
260+
sep="\t",
261+
header=None,
262+
comment="#",
263+
low_memory=False,
217264
)
265+
df_des.columns = ["sunid", "level", "scss", "sid", "description"]
266+
df_des.loc[len(df_des)] = {"sunid": 0, "level": "root"}
218267

219-
pdb_chain_seq_mapping = self._parse_pdb_sequence_file()
268+
# Ensure the 'sunid' column in both DataFrames has the same type
269+
df_des["sunid"] = df_des["sunid"].astype("int64")
270+
return df_des
220271

221-
node_to_pdb_id = df_cla["PDB_ID"].to_dict()
272+
def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
273+
print(f"Process graph")
222274

223-
for node in g.nodes():
224-
pdb_id = node_to_pdb_id[node]
225-
chain_mapping = pdb_chain_seq_mapping.get(pdb_id, {})
275+
sids = nx.get_node_attributes(graph, "sid")
276+
levels = nx.get_node_attributes(graph, "level")
226277

227-
# Add nodes and edges for chains in the mapping
228-
for chain, sequence in chain_mapping.items():
229-
chain_node = f"{pdb_id}_{chain}"
230-
g.add_node(chain_node, sequence=sequence)
231-
g.add_edge(node, chain_node)
278+
sun_ids = []
279+
sids_list = []
232280

233-
print("Compute transitive closure...")
234-
return nx.transitive_closure_dag(g)
281+
selected_sids_dict = self.select_classes(graph)
235282

236-
def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
237-
"""
238-
Processes a directed acyclic graph (DAG) to create a raw dataset in DataFrame format. The dataset includes
239-
Swiss-Prot protein data and their associations with Gene Ontology (GO) terms.
240-
241-
Note:
242-
- GO classes are used as labels in the dataset. Each GO term is represented as a column, and its value
243-
indicates whether a Swiss-Prot protein is associated with that GO term.
244-
- Swiss-Prot proteins serve as samples. There is no 1-to-1 correspondence between Swiss-Prot proteins
245-
and GO terms.
246-
247-
Data Format: pd.DataFrame
248-
- Column 0 : swiss_id (Identifier for SwissProt protein)
249-
- Column 1 : Accession of the protein
250-
- Column 2 : GO IDs (associated GO terms)
251-
- Column 3 : Sequence of the protein
252-
- Column 4 to Column "n": Each column corresponding to a class with value True/False indicating whether the
253-
protein is associated with this GO term.
283+
for sun_id, level in levels.items():
284+
if level == self.scope_hierarchy_level and sun_id in selected_sids_dict:
285+
sun_ids.append(sun_id)
286+
sids_list.append(sids.get(sun_id))
254287

255-
Args:
256-
g (nx.DiGraph): The class hierarchy graph.
288+
# data_df = pd.DataFrame(OrderedDict(sun_id=sun_ids, sids=sids_list))
289+
df_cla = self._get_classification_data()
290+
target_col_name = self.SCOPE_HIERARCHY[self.scope_hierarchy_level]
291+
df_cla = df_cla[df_cla[target_col_name].isin(sun_ids)]
292+
df_cla = df_cla[["sid", target_col_name]]
257293

258-
Returns:
259-
pd.DataFrame: The raw dataset created from the graph.
260-
"""
261-
print(f"Processing graph")
262-
263-
data_df = self._get_swiss_to_go_mapping()
264-
# add ancestors to go ids
265-
data_df["go_ids"] = data_df["go_ids"].apply(
266-
lambda go_ids: sorted(
267-
set(
268-
itertools.chain.from_iterable(
269-
[
270-
[go_id] + list(g.predecessors(go_id))
271-
for go_id in go_ids
272-
if go_id in g.nodes
273-
]
274-
)
275-
)
276-
)
294+
assert (
295+
len(df_cla) > 1
296+
), "dataframe should have more than one instance for `pd.get_dummies` to work as expected"
297+
df_encoded = pd.get_dummies(
298+
df_cla, columns=[target_col_name], drop_first=False, sparse=True
277299
)
278-
# Initialize the GO term labels/columns to False
279-
selected_classes = self.select_classes(g, data_df=data_df)
280-
new_label_columns = pd.DataFrame(
281-
False, index=data_df.index, columns=selected_classes
300+
301+
pdb_chain_seq_mapping = self._parse_pdb_sequence_file()
302+
303+
sequence_hierarchy_df = pd.DataFrame(
304+
columns=list(df_encoded.columns) + ["sids"]
282305
)
283-
data_df = pd.concat([data_df, new_label_columns], axis=1)
284306

285-
# Set True for the corresponding GO IDs in the DataFrame go labels/columns
286-
for index, row in data_df.iterrows():
287-
for go_id in row["go_ids"]:
288-
if go_id in data_df.columns:
289-
data_df.at[index, go_id] = True
307+
for _, row in df_encoded.iterrows():
308+
assert sum(row.iloc[1:].tolist()) == 1
309+
sid = row["sid"]
310+
# SID: 7-char identifier ("d" + 4-char PDB ID + chain ID ('_' for none, '.' for multiple)
311+
# + domain specifier ('_' if not needed))
312+
assert len(sid) == 7, "sid should have 7 characters"
313+
pdb_id, chain_id = sid[1:5], sid[5]
314+
315+
pdb_to_chain_mapping = pdb_chain_seq_mapping.get(pdb_id, None)
316+
if not pdb_to_chain_mapping:
317+
continue
318+
319+
if chain_id != "_":
320+
chain_sequence = pdb_to_chain_mapping.get(chain_id, None)
321+
if chain_sequence:
322+
self._update_or_add_sequence(
323+
chain_sequence, row, sequence_hierarchy_df
324+
)
325+
326+
else:
327+
                # No chain given ('_'): record the sequence of every chain mapped to this PDB id
328+
for chain, chain_sequence in pdb_to_chain_mapping.items():
329+
self._update_or_add_sequence(
330+
chain_sequence, row, sequence_hierarchy_df
331+
)
332+
333+
sequence_hierarchy_df.drop(columns=["sid"], axis=1, inplace=True)
334+
sequence_hierarchy_df = sequence_hierarchy_df[
335+
["sids"] + [col for col in sequence_hierarchy_df.columns if col != "sids"]
336+
]
290337

291338
# This filters the DataFrame to include only the rows where at least one value in the row from 5th column
292339
# onwards is True/non-zero.
293-
# Quote from DeepGo Paper: `For training and testing, we use proteins which have been annotated with at least
294-
# one GO term from the set of the GO terms for the model`
295-
data_df = data_df[data_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)]
296-
return data_df
340+
sequence_hierarchy_df = sequence_hierarchy_df[
341+
sequence_hierarchy_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)
342+
]
343+
return sequence_hierarchy_df
344+
345+
def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]:
346+
pdb_chain_seq_mapping: Dict[str, Dict[str, str]] = {}
347+
for record in SeqIO.parse(
348+
os.path.join(self.raw_dir, self.raw_file_names_dict["PDB"]), "fasta"
349+
):
350+
pdb_id, chain = record.id.split("_")
351+
if str(record.seq):
352+
pdb_chain_seq_mapping.setdefault(pdb_id.lower(), {})[chain.lower()] = (
353+
str(record.seq)
354+
)
355+
return pdb_chain_seq_mapping
356+
357+
@staticmethod
358+
def _update_or_add_sequence(sequence, row, sequence_hierarchy_df):
359+
        # Identify the row's single True label column, then update an existing
        # sequence entry or insert a new one
360+
# Slice the series starting from column 2
361+
sliced_data = row.iloc[1:] # Slice starting from the second column (index 1)
362+
363+
# Get the column name with the True value
364+
true_column = sliced_data.idxmax() if sliced_data.any() else None
365+
366+
if sequence in sequence_hierarchy_df.index:
367+
# Update encoded columns only if they are True
368+
if row[true_column] is True:
369+
sequence_hierarchy_df.loc[sequence, true_column] = True
370+
sequence_hierarchy_df.loc[sequence, "sids"].append(row["sid"])
371+
else:
372+
# Add new row with sequence as the index and hierarchy data
373+
new_row = row
374+
new_row["sids"] = [row["sid"]]
375+
sequence_hierarchy_df.loc[sequence] = new_row
297376

298377
# ------------------------------ Phase: Setup data -----------------------------------
299378
def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
@@ -367,15 +446,33 @@ def raw_file_names_dict(self) -> dict:
367446

368447
class SCOPE(_SCOPeDataExtractor):
369448
READER = ProteinDataReader
449+
THRESHOLD = 1
370450

371451
@property
372452
def _name(self) -> str:
373453
return "test"
374454

375-
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
376-
pass
455+
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict:
456+
# Filter nodes and create a dictionary of node and out-degree
457+
sun_ids_dict = {
458+
node: g.out_degree(node) # Store node and its out-degree
459+
for node in g.nodes
460+
if g.out_degree(node) >= self.THRESHOLD
461+
}
462+
463+
# Return a sorted dictionary (by out-degree or node id)
464+
sorted_dict = dict(
465+
sorted(sun_ids_dict.items(), key=lambda item: item[0], reverse=False)
466+
)
467+
468+
filename = "classes.txt"
469+
with open(os.path.join(self.processed_dir_main, filename), "wt") as fout:
470+
fout.writelines(str(sun_id) + "\n" for sun_id in sorted_dict.keys())
471+
472+
return sorted_dict
377473

378474

379475
if __name__ == "__main__":
380476
scope = SCOPE(scope_version=2.08)
381-
scope._parse_pdb_sequence_file()
477+
g = scope._extract_class_hierarchy("d")
478+
scope._graph_to_raw_dataset(g)

0 commit comments

Comments
 (0)