
Commit 10e4afc

Merge pull request #74 from ChEB-AI/out_dim_dynamic
set `out_dim` dynamically
2 parents 0cd29cf + 5002437 commit 10e4afc

File tree

6 files changed: +103 -107 lines changed


chebai/cli.py

Lines changed: 25 additions & 3 deletions
@@ -1,7 +1,8 @@
-from typing import Dict, Set
+from typing import Dict, Set, Type
 
 from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
 
+from chebai.preprocessing.datasets import XYBaseDataModule
 from chebai.trainer.CustomTrainer import CustomTrainer
 
 
@@ -38,14 +39,35 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser):
         Args:
             parser (LightningArgumentParser): Argument parser instance.
         """
+
+        def call_data_methods(data: Type[XYBaseDataModule]):
+            if data._num_of_labels is None:
+                data.prepare_data()
+                data.setup()
+            return data.num_of_labels
+
+        parser.link_arguments(
+            "data",
+            "model.init_args.out_dim",
+            apply_on="instantiate",
+            compute_fn=call_data_methods,
+        )
+
+        parser.link_arguments(
+            "data.feature_vector_size",
+            "model.init_args.input_dim",
+            apply_on="instantiate",
+        )
+
         for kind in ("train", "val", "test"):
             for average in ("micro-f1", "macro-f1", "balanced-accuracy"):
                 parser.link_arguments(
-                    "model.init_args.out_dim",
+                    "data.num_of_labels",
                     f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels",
+                    apply_on="instantiate",
                 )
         parser.link_arguments(
-            "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
+            "data.num_of_labels", "trainer.callbacks.init_args.num_labels"
         )
 
     @staticmethod
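
With apply_on="instantiate", LightningCLI builds the datamodule first and hands the instance to compute_fn, so out_dim (number of labels) and input_dim (feature vector size) are derived from the prepared data instead of being hard-coded in the model config. A minimal standalone sketch of the same linking pattern; MyModel and MyDataModule in the usage comment are hypothetical placeholders, not chebai classes:

from lightning.pytorch.cli import LightningArgumentParser, LightningCLI

class DynamicDimCLI(LightningCLI):
    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        # The instantiated datamodule is passed to compute_fn, so the model
        # receives a value that only exists after the data has been prepared.
        parser.link_arguments(
            "data",                     # source: the datamodule instance
            "model.init_args.out_dim",  # target: a model constructor argument
            apply_on="instantiate",
            compute_fn=lambda dm: dm.num_of_labels,
        )

# Usage (hypothetical classes): DynamicDimCLI(MyModel, MyDataModule)
# With the link in place, out_dim can be omitted from the model section of the config.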

chebai/models/base.py

Lines changed: 6 additions & 0 deletions
@@ -36,6 +36,7 @@ def __init__(
         self,
         criterion: torch.nn.Module = None,
         out_dim: Optional[int] = None,
+        input_dim: Optional[int] = None,
         train_metrics: Optional[torch.nn.Module] = None,
         val_metrics: Optional[torch.nn.Module] = None,
         test_metrics: Optional[torch.nn.Module] = None,
@@ -57,7 +58,12 @@ def __init__(
                 *exclude_hyperparameter_logging,
             ]
         )
+
         self.out_dim = out_dim
+        self.input_dim = input_dim
+        assert out_dim is not None, "out_dim must be specified"
+        assert input_dim is not None, "input_dim must be specified"
+
         if optimizer_kwargs:
             self.optimizer_kwargs = optimizer_kwargs
         else:
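
The model now requires both out_dim and input_dim at construction time; in CLI runs the argument links above supply them automatically, while direct instantiation must pass them explicitly. A minimal standalone sketch of the same guard (BaseNetSketch is illustrative, not the actual chebai base class; the example values are arbitrary):

from typing import Optional
import torch

class BaseNetSketch(torch.nn.Module):
    def __init__(self, out_dim: Optional[int] = None, input_dim: Optional[int] = None):
        super().__init__()
        # Mirrors the commit: optional in the signature, but asserted so that
        # a missing dimension fails loudly at construction time.
        assert out_dim is not None, "out_dim must be specified"
        assert input_dim is not None, "input_dim must be specified"
        self.out_dim = out_dim
        self.input_dim = input_dim

model = BaseNetSketch(out_dim=854, input_dim=1024)  # example values only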

chebai/preprocessing/datasets/base.py

Lines changed: 56 additions & 22 deletions
@@ -119,6 +119,23 @@ def __init__(
         os.makedirs(os.path.join(self.processed_dir, self.fold_dir), exist_ok=True)
         self.save_hyperparameters()
 
+        self._num_of_labels = None
+        self._feature_vector_size = None
+        self._prepare_data_flag = 1
+        self._setup_data_flag = 1
+
+    @property
+    def num_of_labels(self):
+        assert self._num_of_labels is not None, "num of labels must be set"
+        return self._num_of_labels
+
+    @property
+    def feature_vector_size(self):
+        assert (
+            self._feature_vector_size is not None
+        ), "size of feature vector must be set"
+        return self._feature_vector_size
+
     @property
     def identifier(self) -> tuple:
         """Identifier for the dataset."""
@@ -390,7 +407,17 @@ def predict_dataloader(
         """
         return self.dataloader(self.prediction_kind, shuffle=False, **kwargs)
 
-    def setup(self, **kwargs):
+    def prepare_data(self, *args, **kwargs) -> None:
+        if self._prepare_data_flag != 1:
+            return
+
+        self._prepare_data_flag += 1
+        self._perform_data_preparation(*args, **kwargs)
+
+    def _perform_data_preparation(self, *args, **kwargs) -> None:
+        raise NotImplementedError
+
+    def setup(self, *args, **kwargs) -> None:
         """
         Setup the data module.
 
@@ -399,6 +426,11 @@ def setup(self, **kwargs):
         Args:
             **kwargs: Additional keyword arguments.
         """
+        if self._setup_data_flag != 1:
+            return
+
+        self._setup_data_flag += 1
+
         rank_zero_info(f"Check for processed data in {self.processed_dir}")
         rank_zero_info(f"Cross-validation enabled: {self.use_inner_cross_validation}")
         if any(
@@ -410,6 +442,21 @@ def setup(self, **kwargs):
             if not ("keep_reader" in kwargs and kwargs["keep_reader"]):
                 self.reader.on_finish()
 
+        self._set_processed_data_props()
+
+    def _set_processed_data_props(self):
+
+        data_pt = torch.load(
+            os.path.join(self.processed_dir, self.processed_file_names_dict["data"]),
+            weights_only=False,
+        )
+
+        self._num_of_labels = len(data_pt[0]["labels"])
+        self._feature_vector_size = max(len(d["features"]) for d in data_pt)
+
+        print(f"Number of labels for loaded data: {self._num_of_labels}")
+        print(f"Feature vector size: {self._feature_vector_size}")
+
     def setup_processed(self):
         """
         Setup the processed data.
@@ -482,18 +529,6 @@ def raw_file_names_dict(self) -> dict:
         """
         raise NotImplementedError
 
-    @property
-    def label_number(self) -> int:
-        """
-        Returns the number of labels.
-
-        This property should be implemented by subclasses to provide the number of labels.
-
-        Returns:
-            int: The number of labels. Returns -1 for seq2seq encoding.
-        """
-        raise NotImplementedError
-
 
 class MergedDataset(XYBaseDataModule):
     MERGED = []
@@ -531,7 +566,7 @@ def __init__(
         os.makedirs(self.processed_dir, exist_ok=True)
         super(pl.LightningDataModule, self).__init__(**kwargs)
 
-    def prepare_data(self):
+    def _perform_data_preparation(self):
         """
         Placeholder for data preparation logic.
         """
@@ -547,9 +582,15 @@ def setup(self, **kwargs):
         Args:
             **kwargs: Additional keyword arguments.
         """
+        if self._setup_data_flag != 1:
+            return
+
+        self._setup_data_flag += 1
         for s in self.subsets:
             s.setup(**kwargs)
 
+        self._set_processed_data_props()
+
     def dataloader(self, kind: str, **kwargs) -> DataLoader:
         """
         Creates a DataLoader for a specific subset.
@@ -623,13 +664,6 @@ def processed_file_names(self) -> List[str]:
         """
         return ["test.pt", "train.pt", "validation.pt"]
 
-    @property
-    def label_number(self) -> int:
-        """
-        Returns the number of labels from the first subset.
-        """
-        return self.subsets[0].label_number
-
     @property
     def limits(self):
         """
@@ -725,7 +759,7 @@ def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]
         return splits_file_path
 
     # ------------------------------ Phase: Prepare data -----------------------------------
-    def prepare_data(self, *args: Any, **kwargs: Any) -> None:
+    def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None:
         """
         Prepares the data for the dataset.
 
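
The flags exist because the CLI links above call prepare_data() and setup() on the datamodule while computing out_dim, and Lightning invokes both hooks again when training starts; the counters turn the later calls into no-ops so preparation and setup run exactly once per process. A standalone sketch of the run-once wrapper around the new _perform_data_preparation hook (simplified, not the full XYBaseDataModule):

class RunOnceSketch:
    def __init__(self) -> None:
        self._prepare_data_flag = 1

    def prepare_data(self, *args, **kwargs) -> None:
        # Second and later calls return immediately.
        if self._prepare_data_flag != 1:
            return
        self._prepare_data_flag += 1
        self._perform_data_preparation(*args, **kwargs)

    def _perform_data_preparation(self, *args, **kwargs) -> None:
        # Subclasses override this hook instead of prepare_data itself.
        print("expensive preparation runs only once")

dm = RunOnceSketch()
dm.prepare_data()  # prints once
dm.prepare_data()  # no-op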

chebai/preprocessing/datasets/chebi.py

Lines changed: 3 additions & 43 deletions
@@ -59,7 +59,7 @@ def download(self):
     def raw_file_names(self):
         return ["test.pkl", "train.pkl", "validation.pkl"]
 
-    def prepare_data(self, *args, **kwargs):
+    def _perform_data_preparation(self, *args, **kwargs):
         print("Check for raw data in", self.raw_dir)
         if any(
             not os.path.isfile(os.path.join(self.raw_dir, f))
@@ -88,10 +88,6 @@ def setup_processed(self):
                 os.path.join(self.processed_dir, f"{k}.pt"),
             )
 
-    @property
-    def label_number(self):
-        return 500
-
 
 class JCIData(JCIBase):
     READER = dr.OrdReader
@@ -158,7 +154,7 @@ def __init__(
         )
 
     # ------------------------------ Phase: Prepare data -----------------------------------
-    def prepare_data(self, *args: Any, **kwargs: Any) -> None:
+    def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None:
         """
         Prepares the data for the Chebi dataset.
 
@@ -179,7 +175,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
         Returns:
             None
         """
-        super().prepare_data(args, kwargs)
+        super()._perform_data_preparation(args, kwargs)
 
         if self.chebi_version_train is not None:
             if not os.path.isfile(
@@ -545,10 +541,6 @@ def raw_file_names_dict(self) -> dict:
 
 class JCIExtendedBase(_ChEBIDataExtractor):
 
-    @property
-    def label_number(self):
-        return 500
-
     @property
     def _name(self):
         return "JCI_extended"
@@ -573,16 +565,6 @@ class ChEBIOverX(_ChEBIDataExtractor):
     READER: dr.ChemDataReader = dr.ChemDataReader
     THRESHOLD: int = None
 
-    @property
-    def label_number(self) -> int:
-        """
-        Returns the number of labels in the dataset.
-
-        Returns:
-            int: The number of labels.
-        """
-        return 854
-
     @property
     def _name(self) -> str:
         """
@@ -675,17 +657,6 @@ class ChEBIOver100(ChEBIOverX):
 
     THRESHOLD: int = 100
 
-    def label_number(self) -> int:
-        """
-        Returns the number of labels in the dataset.
-
-        Overrides the base class method to return the correct number of labels for this threshold.
-
-        Returns:
-            int: The number of labels.
-        """
-        return 854
-
 
 class ChEBIOver50(ChEBIOverX):
     """
@@ -699,17 +670,6 @@ class ChEBIOver50(ChEBIOverX):
 
     THRESHOLD: int = 50
 
-    def label_number(self) -> int:
-        """
-        Returns the number of labels in the dataset.
-
-        Overrides the base class method to return the correct number of labels for this threshold.
-
-        Returns:
-            int: The number of labels.
-        """
-        return 1332
-
 
 class ChEBIOver100DeepSMILES(ChEBIOverXDeepSMILES, ChEBIOver100):
     """

chebai/preprocessing/datasets/pubchem.py

Lines changed: 1 addition & 29 deletions
@@ -179,7 +179,7 @@ def processed_file_names(self) -> List[str]:
         """
         return ["test.pt", "train.pt", "validation.pt"]
 
-    def prepare_data(self, *args, **kwargs):
+    def _perform_data_preparation(self, *args, **kwargs):
         """
         Checks for raw data and downloads if necessary.
         """
@@ -692,13 +692,6 @@ class PubchemChem(PubChem):
 
     READER: Type[dr.ChemDataReader] = dr.ChemDataReader
 
-    @property
-    def label_number(self) -> int:
-        """
-        Returns the label number.
-        """
-        return -1
-
 
 class PubchemBPE(PubChem):
     """
@@ -712,13 +705,6 @@ class PubchemBPE(PubChem):
 
     READER: Type[dr.ChemBPEReader] = dr.ChemBPEReader
 
-    @property
-    def label_number(self) -> int:
-        """
-        Returns the label number.
-        """
-        return -1
-
 
 class SWJChem(SWJPreChem):
    """
@@ -732,13 +718,6 @@ class SWJChem(SWJPreChem):
 
     READER: Type[dr.ChemDataUnlabeledReader] = dr.ChemDataUnlabeledReader
 
-    @property
-    def label_number(self) -> int:
-        """
-        Returns the label number.
-        """
-        return -1
-
 
 class SWJBPE(SWJPreChem):
     """
@@ -752,13 +731,6 @@ class SWJBPE(SWJPreChem):
 
     READER: Type[dr.ChemBPEReader] = dr.ChemBPEReader
 
-    @property
-    def label_number(self) -> int:
-        """
-        Returns the label number.
-        """
-        return -1
-
 
 class PubChemTokens(PubChem):
     """
