@@ -61,12 +61,22 @@ def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any]
         self.generate_config = generate_config.copy()
         self.generate_config.update(self.FORCE_GENERATE_CONFIG)
         self.tokenizer = tokenizer
+        self.num_generations = 8

     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
+        micro_batch_size = input_ids.size(0)
         input_ids = input_ids.to(get_current_device())
         attention_mask = attention_mask.to(get_current_device())
-        out = self.model.generate(input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config)
+        gt_answer = None
+        if "gt_answer" in kwargs:
+            gt_answer = kwargs.pop("gt_answer")
+        if self.num_generations > 1:
+            input_ids = input_ids.repeat_interleave(self.num_generations, dim=0)
+            attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0)
+        out = self.model.generate(
+            input_ids, attention_mask=attention_mask, **kwargs, **self.generate_config, tokenizer=self.tokenizer
+        )
         input_len = input_ids.shape[-1]
         new_token_ids = out.sequences[:, input_len:]
         # get log probs
@@ -76,10 +86,13 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
             action_log_probs.append(log_probs_from_logits(logits[:, None, :], new_token_ids[:, i : i + 1]))
         action_log_probs = torch.cat(action_log_probs, dim=1)
         # get action mask
+        response_idx = torch.zeros((new_token_ids.size(0), 2), dtype=torch.int).to(get_current_device())
         action_mask = torch.ones_like(new_token_ids, dtype=attention_mask.dtype)
         if self.tokenizer.eos_token_id is not None:
             for indices in torch.nonzero(new_token_ids == self.tokenizer.eos_token_id):
                 action_mask[indices[0], indices[1] + 1 :] = 0
+        response_idx[:, 0] = input_len
+        response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1

         if attention_mask.size(0) != action_mask.size(0):
             assert action_mask.size(0) % attention_mask.size(0) == 0
@@ -91,7 +104,15 @@ def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwar
             "attention_mask": attention_mask,
             "action_log_probs": action_log_probs,
             "action_mask": action_mask,
+            "response_idx": response_idx,
         }
+
+        data = {k: v.view(micro_batch_size, self.num_generations, v.size(-1)) for k, v in data.items()}
+
+        if gt_answer is not None:
+            # repeat gt_answer for each prompt.
+            data["gt_answer"] = gt_answer.repeat_interleave(self.num_generations, dim=1)
+        data = {k: v.to(get_current_device()) for k, v in data.items()}
         return data

     def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
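Usage note (not part of this diff): a minimal caller-side sketch of how the dict returned by generate could be consumed after this change. The names backend and batch are hypothetical, and it assumes the returned dict also carries the full generated sequences under "input_ids" (a line just above the last hunk's visible context).

# Hypothetical consumer sketch, not part of this PR. Assumes `backend` is an
# instance of the backend patched above and `batch` holds padded prompt tensors
# of shape (micro_batch_size, prompt_len).
data = backend.generate(batch["input_ids"], batch["attention_mask"], gt_answer=batch["gt_answer"])

# After the .view(...) in the diff, every returned tensor carries a generation
# dimension: (micro_batch_size, num_generations, seq_len).
micro_batch_size, num_generations, _ = data["action_mask"].shape

# response_idx[b, g] = (start, end) token positions of the g-th response for
# prompt b, indexed into the padded full sequence (start equals the padded prompt length).
start, end = data["response_idx"][0, 0].tolist()
response_tokens = data["input_ids"][0, 0, start : end + 1]  # assumes "input_ids" holds full sequences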