@@ -35,7 +35,6 @@ class ChebaiBaseNet(LightningModule):
3535 def __init__ (
3636 self ,
3737 criterion : torch .nn .Module = None ,
38- out_dim : Optional [int ] = None ,
3938 train_metrics : Optional [torch .nn .Module ] = None ,
4039 val_metrics : Optional [torch .nn .Module ] = None ,
4140 test_metrics : Optional [torch .nn .Module ] = None ,
@@ -48,7 +47,7 @@ def __init__(
4847 self .save_hyperparameters (
4948 ignore = ["criterion" , "train_metrics" , "val_metrics" , "test_metrics" ]
5049 )
51- self .out_dim = out_dim
50+ self .out_dim = None
5251 if optimizer_kwargs :
5352 self .optimizer_kwargs = optimizer_kwargs
5453 else :
@@ -70,6 +69,14 @@ def __init_subclass__(cls, **kwargs):
7069 else :
7170 _MODEL_REGISTRY [cls .NAME ] = cls
7271
72+ def setup (self , stage : str ) -> None :
73+ if self .trainer and hasattr (self .trainer , "datamodule" ):
74+ self .out_dim = int (self .trainer .datamodule .hparams .num_of_labels )
75+ else :
76+ raise ValueError ("Trainer has no data module" )
77+ assert self .out_dim is not None , "Model output dimension is None"
78+ print (f"Output Dimension for the model: { self .out_dim } " )
79+
7380 def _get_prediction_and_labels (
7481 self , data : Dict [str , Any ], labels : torch .Tensor , output : torch .Tensor
7582 ) -> (torch .Tensor , torch .Tensor ):
0 commit comments