Skip to content

Commit 1285e5a

Browse files
committed
add out_dim to hparams explicitly
1 parent f74b57c commit 1285e5a

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

chebai/models/base.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Any, Dict, Optional, Union, Iterable
2+
from typing import Any, Dict, Iterable, Optional, Union
33

44
import torch
55
from lightning.pytorch.core.module import LightningModule
@@ -49,6 +49,11 @@ def __init__(
4949
if exclude_hyperparameter_logging is None:
5050
exclude_hyperparameter_logging = tuple()
5151
self.criterion = criterion
52+
assert out_dim is not None, "out_dim must be specified"
53+
assert input_dim is not None, "input_dim must be specified"
54+
self.out_dim = out_dim
55+
self.input_dim = input_dim
56+
5257
self.save_hyperparameters(
5358
ignore=[
5459
"criterion",
@@ -59,10 +64,8 @@ def __init__(
5964
]
6065
)
6166

62-
self.out_dim = out_dim
63-
self.input_dim = input_dim
64-
assert out_dim is not None, "out_dim must be specified"
65-
assert input_dim is not None, "input_dim must be specified"
67+
self.hparams["out_dim"] = out_dim
68+
self.hparams["input_dim"] = input_dim
6669

6770
if optimizer_kwargs:
6871
self.optimizer_kwargs = optimizer_kwargs

0 commit comments

Comments
 (0)