Commit 2633810

Merge pull request #100 from ChEB-AI/fix/implicit_model_resigtry
Fix: Use __name__ instead of custom NAME attribute for class registration
2 parents 26329cb + 52ef0d4, commit 2633810

8 files changed (+5, -32 lines)

chebai/models/base.py

Lines changed: 3 additions & 7 deletions
@@ -26,12 +26,8 @@ class ChebaiBaseNet(LightningModule, ABC):
         optimizer_kwargs (Dict[str, Any], optional): Additional keyword arguments for the optimizer. Defaults to None.
         **kwargs: Additional keyword arguments.
 
-    Attributes:
-        NAME (str): The name of the model.
     """
 
-    NAME = None
-
     def __init__(
         self,
         criterion: torch.nn.Module = None,
@@ -88,10 +84,10 @@ def __init_subclass__(cls, **kwargs):
         Args:
             **kwargs: Additional keyword arguments.
         """
-        if cls.NAME in _MODEL_REGISTRY:
-            raise ValueError(f"Model {cls.NAME} does already exist")
+        if cls.__name__ in _MODEL_REGISTRY:
+            raise ValueError(f"Model {cls.__name__} does already exist")
         else:
-            _MODEL_REGISTRY[cls.NAME] = cls
+            _MODEL_REGISTRY[cls.__name__] = cls
 
     def _get_prediction_and_labels(
         self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor
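The mechanism being fixed is Python's __init_subclass__ hook: every subclass of ChebaiBaseNet is recorded in the module-level _MODEL_REGISTRY the moment its class body executes. Keying the registry on cls.__name__ removes the failure mode of the old cls.NAME approach, where a subclass that did not override NAME would presumably inherit the base class's NAME = None and collide with any other such subclass on the None key. A minimal, self-contained sketch of the post-commit behavior (BaseNet and MyModel are illustrative stand-ins, not chebai classes):

# Sketch of __name__-based registration via __init_subclass__.
# BaseNet and MyModel are illustrative stand-ins, not chebai classes.
_MODEL_REGISTRY = {}


class BaseNet:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # The class's own __name__ is the registry key, so subclasses no
        # longer need to declare (and keep in sync) a custom NAME attribute.
        if cls.__name__ in _MODEL_REGISTRY:
            raise ValueError(f"Model {cls.__name__} does already exist")
        _MODEL_REGISTRY[cls.__name__] = cls


class MyModel(BaseNet):
    pass


print(_MODEL_REGISTRY["MyModel"])  # <class '__main__.MyModel'>

The trade-off is that registry keys are now tied to class names, so two model classes sharing the same __name__ in different modules would collide; under the old scheme, distinct NAME strings could have disambiguated them.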

chebai/models/chemberta.py

Lines changed: 0 additions & 4 deletions
@@ -20,8 +20,6 @@
 
 
 class ChembertaPre(ChebaiBaseNet):
-    NAME = "ChembertaPre"
-
     def __init__(self, p=0.2, **kwargs):
         super().__init__(**kwargs)
         self._p = p
@@ -47,8 +45,6 @@ def forward(self, data):
 
 
 class Chemberta(ChebaiBaseNet):
-    NAME = "Chemberta"
-
     def __init__(self, **kwargs):
         # Remove this property in order to prevent it from being stored as a
         # hyper parameter

chebai/models/chemyk.py

Lines changed: 0 additions & 2 deletions
@@ -15,8 +15,6 @@
 
 
 class ChemYK(ChebaiBaseNet):
-    NAME = "ChemYK"
-
     def __init__(self, in_d, out_d, num_classes, **kwargs):
         super().__init__(num_classes, **kwargs)
         d_internal = in_d

chebai/models/electra.py

Lines changed: 1 addition & 12 deletions
@@ -31,16 +31,13 @@ class ElectraPre(ChebaiBaseNet):
         **kwargs: Additional keyword arguments (passed to parent class).
 
     Attributes:
-        NAME (str): Name of the ElectraPre model.
         generator_config (ElectraConfig): Configuration for the generator model.
         generator (ElectraForMaskedLM): Generator model for masked language modeling.
         discriminator_config (ElectraConfig): Configuration for the discriminator model.
         discriminator (ElectraForPreTraining): Discriminator model for pre-training.
         replace_p (float): Probability of replacing tokens during training.
     """
 
-    NAME = "ElectraPre"
-
     def __init__(self, config: Dict[str, Any] = None, **kwargs: Any):
         super().__init__(config=config, **kwargs)
         self.generator_config = ElectraConfig(**config["generator"])
@@ -174,12 +171,8 @@ class Electra(ChebaiBaseNet):
         load_prefix (str, optional): Prefix to filter the state_dict keys from the pretrained checkpoint. Defaults to None.
         **kwargs: Additional keyword arguments.
 
-    Attributes:
-        NAME (str): Name of the Electra model.
     """
 
-    NAME = "Electra"
-
     def _process_batch(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
         """
         Process a batch of data.
@@ -328,7 +321,7 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
             inp = self.electra.embeddings.forward(data["features"].int())
         except RuntimeError as e:
             print(f"RuntimeError at forward: {e}")
-            print(f'data[features]: {data["features"]}')
+            print(f"data[features]: {data['features']}")
             raise e
         inp = self.word_dropout(inp)
         electra = self.electra(inputs_embeds=inp, **kwargs)
@@ -340,8 +333,6 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
 
 
 class ElectraLegacy(ChebaiBaseNet):
-    NAME = "ElectraLeg"
-
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.config = ElectraConfig(**kwargs["config"], output_attentions=True)
@@ -374,8 +365,6 @@ def forward(self, data):
 
 
 class ConeElectra(ChebaiBaseNet):
-    NAME = "ConeElectra"
-
     def _process_batch(self, batch, batch_idx):
         mask = pad_sequence(
             [torch.ones(l + 1, device=self.device) for l in batch.lens],
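Aside from the NAME removals, the only other change in this file, the print(...) line in forward(), merely swaps the f-string's outer quotes from single to double (presumably a formatter preference); the printed output is identical.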

chebai/models/external/__init__.py

Whitespace-only changes.

chebai/models/ffn.py

Lines changed: 1 addition & 3 deletions
@@ -9,15 +9,13 @@
 class FFN(ChebaiBaseNet):
     # Reference: https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/models.py#L121-L139
 
-    NAME = "FFN"
-
     def __init__(
         self,
         input_size: int,
         hidden_layers: List[int] = [
            1024,
        ],
-        **kwargs
+        **kwargs,
    ):
        super().__init__(**kwargs)

chebai/models/lstm.py

Lines changed: 0 additions & 2 deletions
@@ -10,8 +10,6 @@
 
 
 class ChemLSTM(ChebaiBaseNet):
-    NAME = "LSTM"
-
     def __init__(self, in_d, out_d, num_classes, **kwargs):
         super().__init__(num_classes, **kwargs)
         self.lstm = nn.LSTM(in_d, out_d, batch_first=True)

chebai/models/recursive.py

Lines changed: 0 additions & 2 deletions
@@ -11,8 +11,6 @@
 
 
 class Recursive(ChebaiBaseNet):
-    NAME = "REC"
-
     def __init__(self, in_d, out_d, num_classes, **kwargs):
         super().__init__(num_classes, **kwargs)
         mem_len = in_d
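One side effect worth noting: classes whose NAME differed from their class name get new registry keys with this commit. ElectraLegacy was previously registered as "ElectraLeg", ChemLSTM as "LSTM", and Recursive as "REC"; each is now registered under its actual class name.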
