11from dataclasses import dataclass
2- from turtle import hideturtle
2+
3+ import torch
4+ from torch import nn
35from transformers import (
6+ AutoConfig ,
7+ AutoModelForSequenceClassification ,
48 GPTNeoXConfig ,
5- GPTNeoXPreTrainedModel ,
69 GPTNeoXModel ,
7- AutoModelForSequenceClassification ,
8- AutoConfig ,
10+ GPTNeoXPreTrainedModel ,
911)
10- from torch import nn
11- import torch
1212from transformers .utils import ModelOutput
1313
1414
@@ -20,11 +20,13 @@ class GPTNeoxRMOuptput(ModelOutput):
2020
2121 logits : torch .FloatTensor = None
2222
23+
2324class GPTNeoXConfigRM (GPTNeoXConfig ):
2425 model_type = "rm_gptneox_config"
26+
2527 def __init__ (
2628 self ,
27- pooling = "last" ,
29+ pooling = "last" ,
2830 ** kwargs ,
2931 ):
3032 super ().__init__ (** kwargs )
@@ -33,7 +35,7 @@ def __init__(
3335
3436class GPTNeoXRM (GPTNeoXPreTrainedModel ):
3537 config_class = GPTNeoXConfigRM
36- """
38+ """
3739 Reward Model
3840 """
3941
@@ -44,7 +46,9 @@ def __init__(
4446 super ().__init__ (config )
4547 self .gpt_neox = GPTNeoXModel (config )
4648 self .pooling = config .pooling
47- hidden_size = config .hidden_size if self .pooling != "mean-max" else config .hidden_size * 2
49+ hidden_size = (
50+ config .hidden_size if self .pooling != "mean-max" else config .hidden_size * 2
51+ )
4852 self .out_layer = nn .Linear (hidden_size , 1 )
4953
5054 def forward (
@@ -74,22 +78,20 @@ def forward(
7478 ) / attention_mask .sum (dim = 1 ).unsqueeze (- 1 )
7579 elif self .pooling == "last" :
7680 if attention_mask is None :
77- hidden_states = hidden_states [:,- 1 ,:]
81+ hidden_states = hidden_states [:, - 1 , :]
7882 else :
7983 last_idx = attention_mask .cumsum (1 ).argmax (1 )
80- last_idx = last_idx .view (- 1 ,1 , 1 ).expand (- 1 ,1 , hidden_states .size (- 1 ))
81- hidden_states = torch .gather (hidden_states ,1 , last_idx ).squeeze (1 )
84+ last_idx = last_idx .view (- 1 , 1 , 1 ).expand (- 1 , 1 , hidden_states .size (- 1 ))
85+ hidden_states = torch .gather (hidden_states , 1 , last_idx ).squeeze (1 )
8286 elif self .pooling == "mean-max" :
8387 if attention_mask is None :
8488 mean , max = hidden_states .mean (dim = 1 ), hidden_states .max (dim = 1 ).values
85- hidden_states = torch .cat ([mean ,max ],1 )
89+ hidden_states = torch .cat ([mean , max ], 1 )
8690 else :
8791 mean = (hidden_states * attention_mask .unsqueeze (- 1 )).sum (
8892 dim = 1
8993 ) / attention_mask .sum (dim = 1 ).unsqueeze (- 1 )
90- max = (hidden_states * attention_mask .unsqueeze (- 1 )).max (
91- dim = 1
92- ).values
94+ max = (hidden_states * attention_mask .unsqueeze (- 1 )).max (dim = 1 ).values
9395 hidden_states = torch .cat ([mean , max ], 1 )
9496 else :
9597 raise ValueError (f"invalid pooling { self .pooling } " )
@@ -103,4 +105,4 @@ def forward(
103105
104106
105107AutoConfig .register ("rm_gptneox_config" , GPTNeoXConfigRM )
106- AutoModelForSequenceClassification .register (GPTNeoXConfigRM , GPTNeoXRM )
108+ AutoModelForSequenceClassification .register (GPTNeoXConfigRM , GPTNeoXRM )