@@ -31,16 +31,13 @@ class ElectraPre(ChebaiBaseNet):
         **kwargs: Additional keyword arguments (passed to parent class).
 
     Attributes:
-        NAME (str): Name of the ElectraPre model.
         generator_config (ElectraConfig): Configuration for the generator model.
         generator (ElectraForMaskedLM): Generator model for masked language modeling.
         discriminator_config (ElectraConfig): Configuration for the discriminator model.
         discriminator (ElectraForPreTraining): Discriminator model for pre-training.
         replace_p (float): Probability of replacing tokens during training.
     """
 
-    NAME = "ElectraPre"
-
     def __init__(self, config: Dict[str, Any] = None, **kwargs: Any):
         super().__init__(config=config, **kwargs)
         self.generator_config = ElectraConfig(**config["generator"])
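For context, a minimal sketch of instantiating ElectraPre against this constructor, assuming the config dict nests standard transformers.ElectraConfig kwargs under "generator" and "discriminator" keys (only config["generator"] appears in this hunk; the "discriminator" key is inferred from the docstring, and all sizes below are illustrative):

    # Hypothetical config; every kwarg is a standard ElectraConfig field.
    config = {
        "generator": {"vocab_size": 1400, "hidden_size": 64, "num_hidden_layers": 2},
        "discriminator": {"vocab_size": 1400, "hidden_size": 256, "num_hidden_layers": 4},
    }
    model = ElectraPre(config=config)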
@@ -174,12 +171,8 @@ class Electra(ChebaiBaseNet):
         load_prefix (str, optional): Prefix to filter the state_dict keys from the pretrained checkpoint. Defaults to None.
         **kwargs: Additional keyword arguments.
 
-    Attributes:
-        NAME (str): Name of the Electra model.
     """
 
-    NAME = "Electra"
-
     def _process_batch(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
         """
         Process a batch of data.
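The load_prefix argument documented above implies that checkpoint keys are filtered and stripped before loading pretrained weights. A minimal sketch of that pattern, assuming a Lightning-style checkpoint with a "state_dict" entry (the file name and prefix are hypothetical; the actual loading code is not part of this hunk):

    import torch

    ckpt = torch.load("pretrained.ckpt", map_location="cpu")  # hypothetical path
    prefix = "electra."  # hypothetical prefix
    state_dict = {
        k[len(prefix):]: v
        for k, v in ckpt["state_dict"].items()
        if k.startswith(prefix)
    }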
@@ -328,7 +321,7 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
             inp = self.electra.embeddings.forward(data["features"].int())
         except RuntimeError as e:
             print(f"RuntimeError at forward: {e}")
-            print(f'data[features]: {data["features"]}')
+            print(f"data[features]: {data['features']}")
             raise e
         inp = self.word_dropout(inp)
         electra = self.electra(inputs_embeds=inp, **kwargs)
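The surrounding code embeds the token ids manually, applies word dropout, and then feeds the result to the model via inputs_embeds. A self-contained sketch of that HuggingFace pattern (model sizes are arbitrary; passing inputs_embeds skips only the word-embedding lookup, while position and token-type embeddings are still added inside the model):

    import torch
    from transformers import ElectraConfig, ElectraModel

    cfg = ElectraConfig(vocab_size=100, embedding_size=32, hidden_size=32,
                        num_hidden_layers=1, num_attention_heads=2,
                        intermediate_size=64)
    model = ElectraModel(cfg)
    ids = torch.randint(0, 100, (2, 7))
    emb = model.embeddings.word_embeddings(ids)  # (2, 7, 32) raw token vectors
    out = model(inputs_embeds=emb)               # position embeddings added internally
    print(out.last_hidden_state.shape)           # torch.Size([2, 7, 32])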
@@ -340,8 +333,6 @@ def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]:
 
 
 class ElectraLegacy(ChebaiBaseNet):
-    NAME = "ElectraLeg"
-
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.config = ElectraConfig(**kwargs["config"], output_attentions=True)
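Note that ElectraLegacy builds its config with output_attentions=True, so each forward pass also returns per-layer attention maps. In stock transformers that looks like the following (a sketch reusing the toy model from above; the attribute names are HuggingFace's own):

    out = model(ids, output_attentions=True)
    # out.attentions: tuple with one (batch, num_heads, seq_len, seq_len) tensor per layer
    assert len(out.attentions) == cfg.num_hidden_layers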
@@ -374,8 +365,6 @@ def forward(self, data):
 
 
 class ConeElectra(ChebaiBaseNet):
-    NAME = "ConeElectra"
-
     def _process_batch(self, batch, batch_idx):
         mask = pad_sequence(
             [torch.ones(l + 1, device=self.device) for l in batch.lens],
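ConeElectra's _process_batch builds an attention mask of length l + 1 per sequence with pad_sequence; the hunk is truncated before the remaining arguments, so batch_first=True in this sketch is an assumption:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    lens = [3, 5, 2]  # toy sequence lengths standing in for batch.lens
    mask = pad_sequence(
        [torch.ones(l + 1) for l in lens],
        batch_first=True,  # assumed; not visible in the truncated hunk
        padding_value=0,
    )
    # mask.shape == (3, 6): ones up to each length + 1, zeros as padding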