Skip to content

Commit 9a48fdf

Browse files
committed
lightning module: retrieve num_of_labels from data module
1 parent acd487e commit 9a48fdf

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

chebai/cli.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,15 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser):
3838
Args:
3939
parser (LightningArgumentParser): Argument parser instance.
4040
"""
41-
for kind in ("train", "val", "test"):
42-
for average in ("micro-f1", "macro-f1", "balanced-accuracy"):
43-
parser.link_arguments(
44-
"model.init_args.out_dim",
45-
f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels",
46-
)
47-
parser.link_arguments(
48-
"model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
49-
)
41+
# for kind in ("train", "val", "test"):
42+
# for average in ("micro-f1", "macro-f1", "balanced-accuracy"):
43+
# parser.link_arguments(
44+
# "model.init_args.out_dim",
45+
# f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels",
46+
# )
47+
# parser.link_arguments(
48+
# "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
49+
# )
5050
parser.link_arguments(
5151
"data", "model.init_args.criterion.init_args.data_extractor"
5252
)

chebai/models/base.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ class ChebaiBaseNet(LightningModule):
3535
def __init__(
3636
self,
3737
criterion: torch.nn.Module = None,
38-
out_dim: Optional[int] = None,
3938
train_metrics: Optional[torch.nn.Module] = None,
4039
val_metrics: Optional[torch.nn.Module] = None,
4140
test_metrics: Optional[torch.nn.Module] = None,
@@ -48,7 +47,7 @@ def __init__(
4847
self.save_hyperparameters(
4948
ignore=["criterion", "train_metrics", "val_metrics", "test_metrics"]
5049
)
51-
self.out_dim = out_dim
50+
self.out_dim = None
5251
if optimizer_kwargs:
5352
self.optimizer_kwargs = optimizer_kwargs
5453
else:
@@ -70,6 +69,14 @@ def __init_subclass__(cls, **kwargs):
7069
else:
7170
_MODEL_REGISTRY[cls.NAME] = cls
7271

72+
def setup(self, stage: str) -> None:
73+
if self.trainer and hasattr(self.trainer, "datamodule"):
74+
self.out_dim = int(self.trainer.datamodule.hparams.num_of_labels)
75+
else:
76+
raise ValueError("Trainer has no data module")
77+
assert self.out_dim is not None, "Model output dimension is None"
78+
print(f"Output Dimension for the model: {self.out_dim}")
79+
7380
def _get_prediction_and_labels(
7481
self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
7582
) -> (torch.Tensor, torch.Tensor):

0 commit comments

Comments
 (0)