Skip to content

Commit 4fda565

Browse files
committed
model base: make forward method as abstract method rebase ensemble_fr from protein_prediction to dev
1 parent f0e4758 commit 4fda565

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

chebai/models/base.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import logging
2-
from typing import Any, Dict, Optional, Union, Iterable
2+
3+
from abc import ABC, abstractmethod
4+
from typing import Any, Dict, Optional, Union
35

46
import torch
57
from lightning.pytorch.core.module import LightningModule
6-
from torchmetrics import Metric
78

89
from chebai.preprocessing.structures import XYData
910

@@ -12,7 +13,7 @@
1213
_MODEL_REGISTRY = dict()
1314

1415

15-
class ChebaiBaseNet(LightningModule):
16+
class ChebaiBaseNet(LightningModule, ABC):
1617
"""
1718
Base class for Chebai neural network models inheriting from PyTorch Lightning's LightningModule.
1819
@@ -347,6 +348,7 @@ def _log_metrics(self, prefix: str, metrics: torch.nn.Module, batch_size: int):
347348
logger=True,
348349
)
349350

351+
@abstractmethod
350352
def forward(self, x: Dict[str, Any]) -> torch.Tensor:
351353
"""
352354
Defines the forward pass.
@@ -357,7 +359,7 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
357359
Returns:
358360
torch.Tensor: The model output.
359361
"""
360-
raise NotImplementedError
362+
pass
361363

362364
def configure_optimizers(self, **kwargs) -> torch.optim.Optimizer:
363365
"""

0 commit comments

Comments
 (0)