Skip to content

Commit 8dd8b60

Browse files
committed
model base: make forward method as abstract method
1 parent 2d816a6 commit 8dd8b60

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

chebai/models/base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import logging
2+
from abc import ABC, abstractmethod
23
from typing import Any, Dict, Optional, Union
34

45
import torch
56
from lightning.pytorch.core.module import LightningModule
6-
from torchmetrics import Metric
77

88
from chebai.preprocessing.structures import XYData
99

@@ -12,7 +12,7 @@
1212
_MODEL_REGISTRY = dict()
1313

1414

15-
class ChebaiBaseNet(LightningModule):
15+
class ChebaiBaseNet(LightningModule, ABC):
1616
"""
1717
Base class for Chebai neural network models inheriting from PyTorch Lightning's LightningModule.
1818
@@ -315,6 +315,7 @@ def _log_metrics(self, prefix: str, metrics: torch.nn.Module, batch_size: int):
315315
logger=True,
316316
)
317317

318+
@abstractmethod
318319
def forward(self, x: Dict[str, Any]) -> torch.Tensor:
319320
"""
320321
Defines the forward pass.
@@ -325,7 +326,7 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
325326
Returns:
326327
torch.Tensor: The model output.
327328
"""
328-
raise NotImplementedError
329+
pass
329330

330331
def configure_optimizers(self, **kwargs) -> torch.optim.Optimizer:
331332
"""

0 commit comments

Comments
 (0)