Skip to content

Commit 43d4550

Browse files
committed
scope: remove domain level from one-hot encoding
1 parent 431da47 commit 43d4550

File tree

1 file changed

+17
-6
lines changed
  • chebai/preprocessing/datasets/scope

1 file changed

+17
-6
lines changed

chebai/preprocessing/datasets/scope/scope.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -279,31 +279,42 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
279279
sun_ids.pop("root", None)
280280

281281
# data_df = pd.DataFrame(OrderedDict(sun_id=sun_ids, sids=sids_list))
282+
if not sun_ids:
283+
raise RuntimeError("No sunid selected.")
284+
282285
df_cla = self._get_classification_data()
286+
hierarchy_levels = list(self.SCOPE_HIERARCHY.values())
287+
hierarchy_levels.remove("domain")
288+
289+
df_cla = df_cla[["sid", "sunid"] + hierarchy_levels]
283290

284291
for level, selected_sun_ids in sun_ids.items():
285-
df_cla = df_cla[df_cla[self.SCOPE_HIERARCHY[level]].isin(selected_sun_ids)]
292+
if selected_sun_ids:
293+
df_cla = df_cla[
294+
df_cla[self.SCOPE_HIERARCHY[level]].isin(selected_sun_ids)
295+
]
286296

287297
assert (
288298
len(df_cla) > 1
289299
), "dataframe should have more than one instance for `pd.get_dummies` to work as expected"
300+
290301
df_encoded = pd.get_dummies(
291302
df_cla,
292-
columns=list(self.SCOPE_HIERARCHY.values()),
303+
columns=hierarchy_levels,
293304
drop_first=False,
294305
sparse=True,
295306
)
296307

297308
pdb_chain_seq_mapping = self._parse_pdb_sequence_file()
298309

299310
encoded_target_cols = {}
300-
for col in self.SCOPE_HIERARCHY.values():
311+
for col in hierarchy_levels:
301312
encoded_target_cols[col] = [
302313
t_col for t_col in df_encoded.columns if t_col.startswith(col)
303314
]
304315

305316
encoded_target_columns = []
306-
for level in self.SCOPE_HIERARCHY.values():
317+
for level in hierarchy_levels:
307318
encoded_target_columns.extend(encoded_target_cols[level])
308319

309320
sequence_hierarchy_df = pd.DataFrame(columns=["sids"] + encoded_target_columns)
@@ -333,8 +344,8 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
333344
chain_sequence, row, sequence_hierarchy_df, encoded_target_cols
334345
)
335346

336-
sequence_hierarchy_df.drop(columns=["sid"], axis=1, inplace=True)
337347
sequence_hierarchy_df.reset_index(inplace=True)
348+
sequence_hierarchy_df.rename(columns={"index": "sequence"}, inplace=True)
338349
sequence_hierarchy_df["id"] = range(1, len(sequence_hierarchy_df) + 1)
339350

340351
sequence_hierarchy_df = sequence_hierarchy_df[
@@ -457,7 +468,7 @@ def raw_file_names_dict(self) -> dict:
457468

458469
class SCOPE(_SCOPeDataExtractor):
459470
READER = ProteinDataReader
460-
THRESHOLD = 10000
471+
THRESHOLD = 2143
461472

462473
@property
463474
def _name(self) -> str:

0 commit comments

Comments
 (0)