Skip to content

Commit acd487e

Browse files
committed
datamodule: store num_of_labels to hparms
1 parent fd6dd01 commit acd487e

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

chebai/preprocessing/datasets/base.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,21 @@ def setup(self, **kwargs):
401401
if not ("keep_reader" in kwargs and kwargs["keep_reader"]):
402402
self.reader.on_finish()
403403

404+
self._add_num_of_labels_to_hparams()
405+
406+
def _add_num_of_labels_to_hparams(self):
407+
num_of_labels = len(
408+
torch.load(
409+
os.path.join(
410+
self.processed_dir, self.processed_file_names_dict["data"]
411+
),
412+
weights_only=False,
413+
)[0]["labels"]
414+
)
415+
416+
print(f"Number of labels for loaded data: {num_of_labels}")
417+
self.hparams.num_of_labels = num_of_labels
418+
404419
def setup_processed(self):
405420
"""
406421
Setup the processed data.
@@ -541,6 +556,8 @@ def setup(self, **kwargs):
541556
for s in self.subsets:
542557
s.setup(**kwargs)
543558

559+
self._add_num_of_labels_to_hparams()
560+
544561
def dataloader(self, kind: str, **kwargs) -> DataLoader:
545562
"""
546563
Creates a DataLoader for a specific subset.

chebai/preprocessing/datasets/tox21.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ def setup(self, **kwargs) -> None:
129129
):
130130
self.setup_processed()
131131

132+
self._add_num_of_labels_to_hparams()
133+
132134
def _load_data_from_file(self, input_file_path: str) -> List[Dict]:
133135
"""Loads data from a CSV file.
134136
@@ -311,6 +313,8 @@ def setup(self, **kwargs) -> None:
311313
):
312314
self.setup_processed()
313315

316+
self._add_num_of_labels_to_hparams()
317+
314318
def _load_dict(self, input_file_path: str) -> Generator[Dict, None, None]:
315319
"""Loads data from a CSV file as a generator.
316320

0 commit comments

Comments
 (0)