Skip to content

Commit f7c4eb7

Browse files
committed
remove non-positive filter (allow samples with no positive labels)
1 parent 0a00d11 commit f7c4eb7

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

chebai/preprocessing/datasets/chebi.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -301,9 +301,7 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
301301
data = pd.DataFrame(data)
302302
data = data[~data["SMILES"].isnull()]
303303
data = data[[name not in CHEBI_BLACKLIST for name, _ in data.iterrows()]]
304-
# This filters the DataFrame to include only the rows where at least one value in the row from 4th column
305-
# onwards is True/non-zero.
306-
data = data[data.iloc[:, self._LABELS_START_IDX :].any(axis=1)]
304+
307305
return data
308306

309307
# ------------------------------ Phase: Setup data -----------------------------------
@@ -892,7 +890,7 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]:
892890

893891

894892
atom_index = (
895-
"\*",
893+
r"\*",
896894
"H",
897895
"He",
898896
"Li",
@@ -1525,5 +1523,13 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]:
15251523
JCI_500_COLUMNS_INT = [int(n.split(":")[-1]) for n in JCI_500_COLUMNS]
15261524

15271525
if __name__ == "__main__":
1528-
# get arguments from command line
1529-
pass
1526+
data_module_05 = ChEBIOver50Partial(
1527+
chebi_version=241,
1528+
splits_file_path=os.path.join(
1529+
"data", "chebi_v241", "ChEBI50", "splits_80_10_10.csv"
1530+
),
1531+
top_class_id=22712,
1532+
external_data_ratio=0.5,
1533+
)
1534+
data_module_05.prepare_data()
1535+
data_module_05.setup()

0 commit comments

Comments
 (0)