
Commit af54954

Merge branch 'dev' into protein_prediction
2 parents: ca5461f + 7480783

5 files changed: +72 -64 lines changed


chebai/loss/bce_weighted.py

Lines changed: 7 additions & 5 deletions
@@ -43,8 +43,10 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
             self.beta is not None
             and self.data_extractor is not None
             and all(
-                os.path.exists(os.path.join(self.data_extractor.raw_dir, raw_file))
-                for raw_file in self.data_extractor.raw_file_names
+                os.path.exists(
+                    os.path.join(self.data_extractor.processed_dir_main, file_name)
+                )
+                for file_name in self.data_extractor.processed_main_file_names
             )
             and self.pos_weight is None
         ):
@@ -53,13 +55,13 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
                     pd.read_pickle(
                         open(
                             os.path.join(
-                                self.data_extractor.raw_dir,
-                                raw_file_name,
+                                self.data_extractor.processed_dir_main,
+                                file_name,
                             ),
                             "rb",
                         )
                     )
-                    for raw_file_name in self.data_extractor.raw_file_names
+                    for file_name in self.data_extractor.processed_main_file_names
                 ]
             )
             value_counts = []
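
For context, a minimal sketch of the file lookup this hunk switches to, assuming a data extractor that exposes the processed_dir_main and processed_main_file_names properties introduced in this merge (the standalone helper below is illustrative, not part of the repository):

import os

import pandas as pd


def load_processed_frames(data_extractor):
    # Proceed only if every processed (pre-tokenization) pickle exists,
    # mirroring the guard around set_pos_weight above.
    if not all(
        os.path.exists(os.path.join(data_extractor.processed_dir_main, file_name))
        for file_name in data_extractor.processed_main_file_names
    ):
        return None
    # Read each pickle, as the comprehension in the second hunk does.
    return [
        pd.read_pickle(
            open(os.path.join(data_extractor.processed_dir_main, file_name), "rb")
        )
        for file_name in data_extractor.processed_main_file_names
    ]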

chebai/preprocessing/datasets/base.py

Lines changed: 45 additions & 48 deletions
@@ -134,10 +134,15 @@ def base_dir(self) -> str:
             return self._base_dir
         return os.path.join("data", self._name)

+    @property
+    def processed_dir_main(self) -> str:
+        """Name of the directory where processed (but not tokenized) data is stored."""
+        return os.path.join(self.base_dir, "processed")
+
     @property
     def processed_dir(self) -> str:
-        """Name of the directory where the processed data is stored."""
-        return os.path.join(self.base_dir, "processed", *self.identifier)
+        """Name of the directory where the processed and tokenized data is stored."""
+        return os.path.join(self.processed_dir_main, *self.identifier)

     @property
     def raw_dir(self) -> str:
@@ -394,45 +399,61 @@ def setup_processed(self):
         raise NotImplementedError

     @property
-    def processed_file_names(self) -> List[str]:
+    def processed_main_file_names_dict(self) -> dict:
         """
-        Returns the list of processed file names.
-
-        This property should be implemented by subclasses to provide the list of processed file names.
+        Returns a dictionary mapping processed data file names.

         Returns:
-            List[str]: The list of processed file names.
+            dict: A dictionary mapping dataset key to their respective file names.
+                For example, {"data": "data.pkl"}.
         """
         raise NotImplementedError

     @property
-    def raw_file_names(self) -> List[str]:
+    def processed_main_file_names(self) -> List[str]:
         """
-        Returns the list of raw file names.
+        Returns a list of file names for processed data (before tokenization).
+
+        Returns:
+            List[str]: A list of file names corresponding to the processed data.
+        """
+        return list(self.processed_main_file_names_dict.values())

-        This property should be implemented by subclasses to provide the list of raw file names.
+    @property
+    def processed_file_names_dict(self) -> dict:
+        """
+        Returns a dictionary for the processed and tokenized data files.

         Returns:
-            List[str]: The list of raw file names.
+            dict: A dictionary mapping dataset keys to their respective file names.
+                For example, {"data": "data.pt"}.
         """
         raise NotImplementedError

     @property
-    def processed_file_names_dict(self) -> dict:
+    def processed_file_names(self) -> List[str]:
         """
-        Returns the dictionary of processed file names.
+        Returns a list of file names for processed data.

-        This property should be implemented by subclasses to provide the dictionary of processed file names.
+        Returns:
+            List[str]: A list of file names corresponding to the processed data.
+        """
+        return list(self.processed_file_names_dict.values())
+
+    @property
+    def raw_file_names(self) -> List[str]:
+        """
+        Returns the list of raw file names.

         Returns:
-            dict: The dictionary of processed file names.
+            List[str]: The list of raw file names.
         """
-        raise NotImplementedError
+        return list(self.raw_file_names_dict.values())

     @property
     def raw_file_names_dict(self) -> dict:
         """
-        Returns the dictionary of raw file names.
+        Returns the dictionary of raw file names (i.e., files that are directly obtained from an external source).

         This property should be implemented by subclasses to provide the dictionary of raw file names.

@@ -705,7 +726,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
         """
         print("Checking for processed data in", self.processed_dir_main)

-        processed_name = self.processed_dir_main_file_names_dict["data"]
+        processed_name = self.processed_main_file_names_dict["data"]
         if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)):
             print("Missing processed data file (`data.pkl` file)")
             os.makedirs(self.processed_dir_main, exist_ok=True)
@@ -796,7 +817,7 @@ def setup_processed(self) -> None:
                 self._load_data_from_file(
                     os.path.join(
                         self.processed_dir_main,
-                        self.processed_dir_main_file_names_dict["data"],
+                        self.processed_main_file_names_dict["data"],
                     )
                 ),
                 os.path.join(self.processed_dir, self.processed_file_names_dict["data"]),
@@ -1118,47 +1139,23 @@ def processed_dir_main(self) -> str:
         )

     @property
-    def processed_dir(self) -> str:
+    def processed_main_file_names_dict(self) -> dict:
         """
-        Returns the specific directory path for processed data, including identifiers.
+        Returns a dictionary mapping processed data file names.

         Returns:
-            str: The path to the processed data directory, including additional identifiers.
-        """
-        return os.path.join(
-            self.processed_dir_main,
-            *self.identifier,
-        )
-
-    @property
-    def processed_dir_main_file_names_dict(self) -> dict:
-        """
-        Returns a dictionary mapping processed data file names, processed by `prepare_data` method.
-
-        Returns:
-            dict: A dictionary mapping dataset types to their respective processed file names.
+            dict: A dictionary mapping dataset key to their respective file names.
                 For example, {"data": "data.pkl"}.
         """
         return {"data": "data.pkl"}

     @property
     def processed_file_names_dict(self) -> dict:
         """
-        Returns a dictionary mapping processed and transformed data file names to their final formats, which are
-        processed by `setup` method.
+        Returns a dictionary for the processed and tokenized data files.

         Returns:
-            dict: A dictionary mapping dataset types to their respective final file names.
+            dict: A dictionary mapping dataset keys to their respective file names.
                 For example, {"data": "data.pt"}.
         """
         return {"data": "data.pt"}
-
-    @property
-    def processed_file_names(self) -> List[str]:
-        """
-        Returns a list of file names for processed data.
-
-        Returns:
-            List[str]: A list of file names corresponding to the processed data.
-        """
-        return list(self.processed_file_names_dict.values())
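
The net effect of this refactor is that subclasses override only the *_file_names_dict properties, while the list variants are derived from them in the base class. A hypothetical subclass, sketched for illustration (the class name and file names are made up; only the property names come from this diff):

import os
from typing import List


class ToyDataModule:
    """Illustrative stand-in for an XYBaseDataModule subclass."""

    base_dir = os.path.join("data", "toy")

    # Subclasses now only override the *_file_names_dict properties ...
    @property
    def processed_main_file_names_dict(self) -> dict:
        return {"data": "data.pkl"}  # written by prepare_data

    @property
    def processed_file_names_dict(self) -> dict:
        return {"data": "data.pt"}   # tokenized output of setup

    @property
    def raw_file_names_dict(self) -> dict:
        return {"toy": "toy.csv"}    # files obtained from an external source

    # ... while the list views are derived exactly as in the base class:
    @property
    def processed_main_file_names(self) -> List[str]:
        return list(self.processed_main_file_names_dict.values())

    @property
    def processed_file_names(self) -> List[str]:
        return list(self.processed_file_names_dict.values())

    @property
    def raw_file_names(self) -> List[str]:
        return list(self.raw_file_names_dict.values())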

chebai/preprocessing/datasets/chebi.py

Lines changed: 6 additions & 4 deletions
@@ -185,7 +185,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
         if not os.path.isfile(
             os.path.join(
                 self._chebi_version_train_obj.processed_dir_main,
-                self._chebi_version_train_obj.processed_dir_main_file_names_dict[
+                self._chebi_version_train_obj.processed_main_file_names_dict[
                     "data"
                 ],
             )
@@ -216,9 +216,7 @@ def _load_chebi(self, version: int) -> str:
         Returns:
             str: The file path of the loaded ChEBI ontology.
         """
-        chebi_name = (
-            f"chebi.obo" if version == self.chebi_version else f"chebi_v{version}.obo"
-        )
+        chebi_name = self.raw_file_names_dict["chebi"]
         chebi_path = os.path.join(self.raw_dir, chebi_name)
         if not os.path.isfile(chebi_path):
             print(
@@ -540,6 +538,10 @@ def processed_dir(self) -> str:
         else:
             return os.path.join(res, f"single_{self.single_class}")

+    @property
+    def raw_file_names_dict(self) -> dict:
+        return {"chebi": "chebi.obo"}
+

 class JCIExtendedBase(_ChEBIDataExtractor):

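A condensed sketch of how the ontology path is now resolved, assuming the raw_dir property of the data module and the raw_file_names_dict added above (the standalone function is illustrative and the download branch is elided):

import os


def resolve_chebi_path(extractor) -> str:
    # The ontology file name now comes from raw_file_names_dict
    # instead of being formatted inline from the requested version.
    chebi_name = extractor.raw_file_names_dict["chebi"]  # "chebi.obo"
    chebi_path = os.path.join(extractor.raw_dir, chebi_name)
    if not os.path.isfile(chebi_path):
        print(f"ChEBI ontology not found at {chebi_path}; it would be downloaded here.")
    return chebi_path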

chebai/preprocessing/datasets/protein_pretraining.py

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
             *args: Additional positional arguments.
             **kwargs: Additional keyword arguments.
         """
-        processed_name = self.processed_dir_main_file_names_dict["data"]
+        processed_name = self.processed_main_file_names_dict["data"]
         if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)):
             print("Missing processed data file (`data.pkl` file)")
             os.makedirs(self.processed_dir_main, exist_ok=True)

chebai/result/utils.py

Lines changed: 13 additions & 6 deletions
@@ -66,6 +66,13 @@ def _run_batch(batch, model, collate):
     return preds, labels


+def _concat_tuple(l):
+    if isinstance(l[0], tuple):
+        print(l[0])
+        return tuple([torch.cat([t[i] for t in l]) for i in range(len(l[0]))])
+    return torch.cat(l)
+
+
 def evaluate_model(
     model: ChebaiBaseNet,
     data_module: XYBaseDataModule,
@@ -125,12 +132,12 @@ def evaluate_model(
         if buffer_dir is not None:
             if n_saved * batch_size >= save_batch_size:
                 torch.save(
-                    torch.cat(preds_list),
+                    _concat_tuple(preds_list),
                     os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
                 )
                 if labels_list[0] is not None:
                     torch.save(
-                        torch.cat(labels_list),
+                        _concat_tuple(labels_list),
                         os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"),
                     )
                 preds_list = []
@@ -141,20 +148,20 @@ def evaluate_model(
             n_saved += 1

    if buffer_dir is None:
-        test_preds = torch.cat(preds_list)
+        test_preds = _concat_tuple(preds_list)
        if labels_list is not None:
-            test_labels = torch.cat(labels_list)
+            test_labels = _concat_tuple(labels_list)

            return test_preds, test_labels
        return test_preds, None
    else:
        torch.save(
-            torch.cat(preds_list),
+            _concat_tuple(preds_list),
            os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
        )
        if labels_list[0] is not None:
            torch.save(
-                torch.cat(labels_list),
+                _concat_tuple(labels_list),
                os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"),
            )

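A small usage sketch of the new _concat_tuple helper, with illustrative shapes: lists of plain tensors are concatenated as before, while lists of tuples (e.g., a model returning several outputs per batch) are concatenated position-wise.

import torch

from chebai.result.utils import _concat_tuple  # helper added in this merge

# Plain tensors per batch: equivalent to torch.cat.
flat = _concat_tuple([torch.zeros(4, 10), torch.zeros(2, 10)])
assert flat.shape == (6, 10)

# Tuple outputs per batch: each tuple position is concatenated across batches.
nested = _concat_tuple(
    [
        (torch.zeros(4, 10), torch.zeros(4, 3)),
        (torch.zeros(2, 10), torch.zeros(2, 3)),
    ]
)
assert nested[0].shape == (6, 10) and nested[1].shape == (6, 3)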
