Skip to content

Commit 9e20960

Browse files
committed
Modify logic to call prepare and setup during instantiation of the data module
- Lightning-AI/pytorch-lightning#20602 (comment)
1 parent 381452e commit 9e20960

File tree

9 files changed

+101
-41
lines changed

9 files changed

+101
-41
lines changed

chebai/cli.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,26 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser):
3838
Args:
3939
parser (LightningArgumentParser): Argument parser instance.
4040
"""
41-
# for kind in ("train", "val", "test"):
42-
# for average in ("micro-f1", "macro-f1", "balanced-accuracy"):
43-
# parser.link_arguments(
44-
# "model.init_args.out_dim",
45-
# f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels",
46-
# )
47-
# parser.link_arguments(
48-
# "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
49-
# )
41+
42+
parser.link_arguments(
43+
"data.num_of_labels", "model.init_args.out_dim", apply_on="instantiate"
44+
)
45+
parser.link_arguments(
46+
"data.feature_vector_size",
47+
"model.init_args.input_dim",
48+
apply_on="instantiate",
49+
)
50+
51+
for kind in ("train", "val", "test"):
52+
for average in ("micro-f1", "macro-f1", "balanced-accuracy"):
53+
parser.link_arguments(
54+
"data.num_of_labels",
55+
f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels",
56+
apply_on="instantiate",
57+
)
58+
parser.link_arguments(
59+
"data.num_of_labels", "trainer.callbacks.init_args.num_labels"
60+
)
5061
parser.link_arguments(
5162
"data", "model.init_args.criterion.init_args.data_extractor"
5263
)

chebai/models/base.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ class ChebaiBaseNet(LightningModule):
3535
def __init__(
3636
self,
3737
criterion: torch.nn.Module = None,
38+
out_dim: Optional[int] = None,
39+
input_dim: Optional[int] = None,
3840
train_metrics: Optional[torch.nn.Module] = None,
3941
val_metrics: Optional[torch.nn.Module] = None,
4042
test_metrics: Optional[torch.nn.Module] = None,
@@ -47,7 +49,12 @@ def __init__(
4749
self.save_hyperparameters(
4850
ignore=["criterion", "train_metrics", "val_metrics", "test_metrics"]
4951
)
50-
self.out_dim = None
52+
53+
self.out_dim = out_dim
54+
self.input_dim = input_dim
55+
assert out_dim is not None, "out_dim must be specified"
56+
assert input_dim is not None, "input_dim must be specified"
57+
5158
if optimizer_kwargs:
5259
self.optimizer_kwargs = optimizer_kwargs
5360
else:
@@ -69,14 +76,6 @@ def __init_subclass__(cls, **kwargs):
6976
else:
7077
_MODEL_REGISTRY[cls.NAME] = cls
7178

72-
def setup(self, stage: str) -> None:
73-
if self.trainer and hasattr(self.trainer, "datamodule"):
74-
self.out_dim = int(self.trainer.datamodule.hparams.num_of_labels)
75-
else:
76-
raise ValueError("Trainer has no data module")
77-
assert self.out_dim is not None, "Model output dimension is None"
78-
print(f"Output Dimension for the model: {self.out_dim}")
79-
8079
def _get_prediction_and_labels(
8180
self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
8281
) -> (torch.Tensor, torch.Tensor):

chebai/models/ffn.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,9 @@ def __init__(
2020
**kwargs
2121
):
2222
super().__init__(**kwargs)
23-
self.input_size = input_size
24-
self.hidden_layers = hidden_layers
25-
26-
def setup(self, stage: str) -> None:
27-
super().setup(stage)
28-
2923
layers = []
30-
current_layer_input_size = self.input_size
31-
for hidden_dim in self.hidden_layers:
24+
current_layer_input_size = input_size
25+
for hidden_dim in hidden_layers:
3226
layers.append(MLPBlock(current_layer_input_size, hidden_dim))
3327
layers.append(Residual(MLPBlock(hidden_dim, hidden_dim)))
3428
current_layer_input_size = hidden_dim

chebai/preprocessing/datasets/base.py

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,25 @@ def __init__(
117117
os.makedirs(os.path.join(self.processed_dir, self.fold_dir), exist_ok=True)
118118
self.save_hyperparameters()
119119

120+
self._num_of_labels = None
121+
self._feature_vector_size = None
122+
self._prepare_data_flag = 1
123+
self._setup_data_flag = 1
124+
self.prepare_data()
125+
self.setup()
126+
127+
@property
128+
def num_of_labels(self):
129+
assert self._num_of_labels is not None, "num of labels must be set"
130+
return self._num_of_labels
131+
132+
@property
133+
def feature_vector_size(self):
134+
assert (
135+
self._feature_vector_size is not None
136+
), "size of feature vector must be set"
137+
return self._feature_vector_size
138+
120139
@property
121140
def identifier(self) -> tuple:
122141
"""Identifier for the dataset."""
@@ -381,6 +400,12 @@ def predict_dataloader(
381400
"""
382401
return self.dataloader(self.prediction_kind, shuffle=False, **kwargs)
383402

403+
def prepare_data(self) -> None:
404+
if self._prepare_data_flag != 1:
405+
return
406+
407+
self._prepare_data_flag += 1
408+
384409
def setup(self, **kwargs):
385410
"""
386411
Setup the data module.
@@ -390,6 +415,11 @@ def setup(self, **kwargs):
390415
Args:
391416
**kwargs: Additional keyword arguments.
392417
"""
418+
if self._setup_data_flag != 1:
419+
return
420+
421+
self._setup_data_flag += 1
422+
393423
rank_zero_info(f"Check for processed data in {self.processed_dir}")
394424
rank_zero_info(f"Cross-validation enabled: {self.use_inner_cross_validation}")
395425
if any(
@@ -401,20 +431,20 @@ def setup(self, **kwargs):
401431
if not ("keep_reader" in kwargs and kwargs["keep_reader"]):
402432
self.reader.on_finish()
403433

404-
self._add_num_of_labels_to_hparams()
434+
self._set_processed_data_props()
405435

406-
def _add_num_of_labels_to_hparams(self):
407-
num_of_labels = len(
408-
torch.load(
409-
os.path.join(
410-
self.processed_dir, self.processed_file_names_dict["data"]
411-
),
412-
weights_only=False,
413-
)[0]["labels"]
414-
)
436+
def _set_processed_data_props(self):
415437

416-
print(f"Number of labels for loaded data: {num_of_labels}")
417-
self.hparams.num_of_labels = num_of_labels
438+
single_data_instance = torch.load(
439+
os.path.join(self.processed_dir, self.processed_file_names_dict["data"]),
440+
weights_only=False,
441+
)[0]
442+
443+
self._num_of_labels = len(single_data_instance["labels"])
444+
self._feature_vector_size = len(single_data_instance["features"])
445+
446+
print(f"Number of labels for loaded data: {self._num_of_labels}")
447+
print(f"Feature vector size: {self._feature_vector_size}")
418448

419449
def setup_processed(self):
420450
"""
@@ -541,6 +571,7 @@ def prepare_data(self):
541571
"""
542572
Placeholder for data preparation logic.
543573
"""
574+
super().prepare_data()
544575
for s in self.subsets:
545576
s.prepare_data()
546577

@@ -553,10 +584,14 @@ def setup(self, **kwargs):
553584
Args:
554585
**kwargs: Additional keyword arguments.
555586
"""
587+
if self._setup_data_flag != 1:
588+
return
589+
590+
self._setup_data_flag += 1
556591
for s in self.subsets:
557592
s.setup(**kwargs)
558593

559-
self._add_num_of_labels_to_hparams()
594+
self._set_processed_data_props()
560595

561596
def dataloader(self, kind: str, **kwargs) -> DataLoader:
562597
"""
@@ -752,6 +787,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
752787
Returns:
753788
None
754789
"""
790+
super().prepare_data()
755791
print("Checking for processed data in", self.processed_dir_main)
756792

757793
processed_name = self.processed_main_file_names_dict["data"]

chebai/preprocessing/datasets/chebi.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def raw_file_names(self):
6060
return ["test.pkl", "train.pkl", "validation.pkl"]
6161

6262
def prepare_data(self, *args, **kwargs):
63+
super().prepare_data()
6364
print("Check for raw data in", self.raw_dir)
6465
if any(
6566
not os.path.isfile(os.path.join(self.raw_dir, f))

chebai/preprocessing/datasets/deepGO/go_uniprot.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -783,6 +783,11 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
783783
Raises:
784784
FileNotFoundError: If the processed data file does not exist.
785785
"""
786+
if self._prepare_data_flag != 1:
787+
return
788+
789+
self._prepare_data_flag += 1
790+
786791
print("Checking for processed data in", self.processed_dir_main)
787792

788793
processed_name = self.processed_main_file_names_dict["data"]

chebai/preprocessing/datasets/deepGO/protein_pretraining.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
6464
*args: Additional positional arguments.
6565
**kwargs: Additional keyword arguments.
6666
"""
67+
if self._prepare_data_flag != 1:
68+
return
69+
70+
self._prepare_data_flag += 1
71+
6772
processed_name = self.processed_main_file_names_dict["data"]
6873
if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)):
6974
print("Missing processed data file (`data.pkl` file)")

chebai/preprocessing/datasets/pubchem.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def prepare_data(self, *args, **kwargs):
183183
"""
184184
Checks for raw data and downloads if necessary.
185185
"""
186+
super().prepare_data()
186187
print("Check for raw data in", self.raw_dir)
187188
if any(
188189
not os.path.isfile(os.path.join(self.raw_dir, f))

chebai/preprocessing/datasets/tox21.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,10 @@ def setup_processed(self) -> None:
118118

119119
def setup(self, **kwargs) -> None:
120120
"""Sets up the dataset by downloading and processing if necessary."""
121+
if self._setup_data_flag != 1:
122+
return
123+
124+
self._setup_data_flag += 1
121125
if any(
122126
not os.path.isfile(os.path.join(self.raw_dir, f))
123127
for f in self.raw_file_names
@@ -129,7 +133,7 @@ def setup(self, **kwargs) -> None:
129133
):
130134
self.setup_processed()
131135

132-
self._add_num_of_labels_to_hparams()
136+
self._set_processed_data_props()
133137

134138
def _load_data_from_file(self, input_file_path: str) -> List[Dict]:
135139
"""Loads data from a CSV file.
@@ -302,6 +306,10 @@ def setup_processed(self) -> None:
302306

303307
def setup(self, **kwargs) -> None:
304308
"""Sets up the dataset by downloading and processing if necessary."""
309+
if self._setup_data_flag != 1:
310+
return
311+
312+
self._setup_data_flag += 1
305313
if any(
306314
not os.path.isfile(os.path.join(self.raw_dir, f))
307315
for f in self.raw_file_names
@@ -313,7 +321,7 @@ def setup(self, **kwargs) -> None:
313321
):
314322
self.setup_processed()
315323

316-
self._add_num_of_labels_to_hparams()
324+
self._set_processed_data_props()
317325

318326
def _load_dict(self, input_file_path: str) -> Generator[Dict, None, None]:
319327
"""Loads data from a CSV file as a generator.

0 commit comments

Comments
 (0)