@@ -161,17 +161,18 @@ def filter_dict(d: Dict[str, Any], filter_key: str) -> Dict[str, Any]:
161161 }
162162
163163
164- class Electra (ChebaiBaseNet ):
165- """
166- Electra model implementation inherited from ChebaiBaseNet.
164+ class ElectraProcessingMixIn :
165+ """Mixin class for processing batches and outputs for Electra models."""
167166
168- Args:
169- config (Dict[str, Any], optional): Configuration parameters for the Electra model. Defaults to None.
170- pretrained_checkpoint (str, optional): Path to the pretrained checkpoint file. Defaults to None.
171- load_prefix (str, optional): Prefix to filter the state_dict keys from the pretrained checkpoint. Defaults to None.
172- **kwargs: Additional keyword arguments.
167+ @property
168+ def as_pretrained (self ) -> ElectraModel :
169+ """
170+ Get the pretrained Electra model.
173171
174- """
172+ Returns:
173+ ElectraModel: The pretrained Electra model.
174+ """
175+ return self .electra .electra
175176
176177 def _process_batch (self , batch : Dict [str , Any ], batch_idx : int ) -> Dict [str , Any ]:
177178 """
@@ -209,15 +210,61 @@ def _process_batch(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any
209210 idents = batch .additional_fields ["idents" ],
210211 )
211212
212- @property
213- def as_pretrained (self ) -> ElectraModel :
213+ def _process_for_loss (
214+ self ,
215+ model_output : Dict [str , Tensor ],
216+ labels : Tensor ,
217+ loss_kwargs : Dict [str , Any ],
218+ ) -> Tuple [Tensor , Tensor , Dict [str , Any ]]:
214219 """
215- Get the pretrained Electra model.
220+ Process the model output for calculating the loss.
221+
222+ Args:
223+ model_output (Dict[str, Tensor]): The output of the model.
224+ labels (Tensor): The target labels.
225+ loss_kwargs (Dict[str, Any]): Additional loss arguments.
216226
217227 Returns:
218- ElectraModel: The pretrained Electra model.
228+ tuple: A tuple containing the processed model output, labels, and loss arguments .
219229 """
220- return self .electra .electra
230+ kwargs_copy = dict (loss_kwargs )
231+ if labels is not None :
232+ labels = labels .float ()
233+ return model_output ["logits" ], labels , kwargs_copy
234+
235+ def _get_prediction_and_labels (
236+ self , data : Dict [str , Any ], labels : Tensor , model_output : Dict [str , Tensor ]
237+ ) -> Tuple [Tensor , Tensor ]:
238+ """
239+ Get the predictions and labels from the model output. Applies a sigmoid to the model output.
240+
241+ Args:
242+ data (Dict[str, Any]): The input data.
243+ labels (Tensor): The target labels.
244+ model_output (Dict[str, Tensor]): The output of the model.
245+
246+ Returns:
247+ tuple: A tuple containing the predictions and labels.
248+ """
249+ d = model_output ["logits" ]
250+ loss_kwargs = data .get ("loss_kwargs" , dict ())
251+ if "non_null_labels" in loss_kwargs :
252+ n = loss_kwargs ["non_null_labels" ]
253+ d = d [n ]
254+ return torch .sigmoid (d ), labels .int () if labels is not None else None
255+
256+
257+ class Electra (ElectraProcessingMixIn , ChebaiBaseNet ):
258+ """
259+ Electra model implementation inherited from ChebaiBaseNet.
260+
261+ Args:
262+ config (Dict[str, Any], optional): Configuration parameters for the Electra model. Defaults to None.
263+ pretrained_checkpoint (str, optional): Path to the pretrained checkpoint file. Defaults to None.
264+ load_prefix (str, optional): Prefix to filter the state_dict keys from the pretrained checkpoint. Defaults to None.
265+ **kwargs: Additional keyword arguments.
266+
267+ """
221268
222269 def __init__ (
223270 self ,
@@ -262,49 +309,6 @@ def __init__(
262309 else :
263310 self .electra = ElectraModel (config = self .config )
264311
265- def _process_for_loss (
266- self ,
267- model_output : Dict [str , Tensor ],
268- labels : Tensor ,
269- loss_kwargs : Dict [str , Any ],
270- ) -> Tuple [Tensor , Tensor , Dict [str , Any ]]:
271- """
272- Process the model output for calculating the loss.
273-
274- Args:
275- model_output (Dict[str, Tensor]): The output of the model.
276- labels (Tensor): The target labels.
277- loss_kwargs (Dict[str, Any]): Additional loss arguments.
278-
279- Returns:
280- tuple: A tuple containing the processed model output, labels, and loss arguments.
281- """
282- kwargs_copy = dict (loss_kwargs )
283- if labels is not None :
284- labels = labels .float ()
285- return model_output ["logits" ], labels , kwargs_copy
286-
287- def _get_prediction_and_labels (
288- self , data : Dict [str , Any ], labels : Tensor , model_output : Dict [str , Tensor ]
289- ) -> Tuple [Tensor , Tensor ]:
290- """
291- Get the predictions and labels from the model output. Applies a sigmoid to the model output.
292-
293- Args:
294- data (Dict[str, Any]): The input data.
295- labels (Tensor): The target labels.
296- model_output (Dict[str, Tensor]): The output of the model.
297-
298- Returns:
299- tuple: A tuple containing the predictions and labels.
300- """
301- d = model_output ["logits" ]
302- loss_kwargs = data .get ("loss_kwargs" , dict ())
303- if "non_null_labels" in loss_kwargs :
304- n = loss_kwargs ["non_null_labels" ]
305- d = d [n ]
306- return torch .sigmoid (d ), labels .int () if labels is not None else None
307-
308312 def forward (self , data : Dict [str , Tensor ], ** kwargs : Any ) -> Dict [str , Any ]:
309313 """
310314 Forward pass of the Electra model.
0 commit comments