@@ -156,31 +156,29 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
                 past_keys = past_keys.view(batch.size, -1, *past_keys.shape[-2:])
                 past_values = past_values.view(batch.size, -1, *past_values.shape[-2:])

-                _, num_heads, head_dim, padded_sequence_length = past_keys.shape
+                _, num_heads, padded_sequence_length, head_dim = past_values.shape

-                padded_past_keys_shape = (
+                padded_past_values_shape = (
                     total_batch_size,
                     num_heads,
-                    head_dim,
                     max_sequence_length - 1,
+                    head_dim,
                 )

-                # head_dim is last for BLOOM
-                if past_values.shape[-1] == head_dim:
-                    past_values_head_dim_last = True
-                    padded_past_values_shape = (
+                # seq_length is last for BLOOM
+                if past_keys.shape[-2] == head_dim:
+                    past_keys_head_dim_last = False
+                    padded_past_keys_shape = (
                         total_batch_size,
                         num_heads,
-                        max_sequence_length - 1,
                         head_dim,
+                        max_sequence_length - 1,
                     )
-                elif past_values.shape[-2] == head_dim:
-                    past_values_head_dim_last = False
-                    padded_past_values_shape = padded_past_keys_shape
+                elif past_keys.shape[-1] == head_dim:
+                    past_keys_head_dim_last = True
+                    padded_past_keys_shape = padded_past_values_shape
                 else:
-                    raise ValueError(
-                        f"past_values shape {past_values.shape} is not valid"
-                    )
+                    raise ValueError(f"past_keys shape {past_keys.shape} is not valid")

                 # This will run only once per layer
                 if j == len(past_key_values):
@@ -197,24 +195,24 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
                     past_key_values.append((padded_past_keys, padded_past_values))

                 # We slice the past keys and values to remove the padding from previous batches
-                past_key_values[j][0][
-                    start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
-                ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
-
-                if past_values_head_dim_last:
-                    past_key_values[j][1][
+                if past_keys_head_dim_last:
+                    past_key_values[j][0][
                         start_index:end_index,
                         :,
                         -(batch.max_sequence_length - 1) :,
                         :,
-                    ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
+                    ] = past_keys[:, :, -(batch.max_sequence_length - 1) :, :]
                 else:
-                    past_key_values[j][1][
+                    past_key_values[j][0][
                         start_index:end_index,
                         :,
                         :,
                         -(batch.max_sequence_length - 1) :,
-                    ] = past_values[:, :, :, -(batch.max_sequence_length - 1) :]
+                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
+
+                past_key_values[j][1][
+                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
+                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]

             start_index += batch.size

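The rewritten branch above rests on one layout fact: the values tensor always keeps `head_dim` as its last dimension, so it can be unpacked and copied the same way for every model, while the keys may arrive either as `[..., seq_length, head_dim]` (most models) or transposed as `[..., head_dim, seq_length]` (BLOOM), which is the only case that still needs a layout check. Below is a minimal standalone sketch of that padded, right-aligned copy; the concrete sizes and variable names are illustrative assumptions, not the repository's code.

```python
import torch

# Illustrative sizes only.
batch_size, num_heads, head_dim, seq_len, max_len = 1, 2, 4, 3, 5

# Values always store seq_len before head_dim.
past_values = torch.randn(batch_size, num_heads, seq_len, head_dim)
# BLOOM-style keys are transposed: head_dim before seq_len.
past_keys = torch.randn(batch_size, num_heads, head_dim, seq_len)

# Same test as the diff: if the last key dimension is head_dim,
# the keys use the "standard" layout; otherwise the BLOOM layout.
keys_head_dim_last = past_keys.shape[-1] == head_dim

# Allocate padded caches for the merged batch and copy the existing
# entries right-aligned, mirroring the slicing in the hunk above.
padded_values = past_values.new_zeros(batch_size, num_heads, max_len, head_dim)
padded_values[:, :, -seq_len:, :] = past_values

if keys_head_dim_last:
    padded_keys = past_keys.new_zeros(batch_size, num_heads, max_len, head_dim)
    padded_keys[:, :, -seq_len:, :] = past_keys
else:
    padded_keys = past_keys.new_zeros(batch_size, num_heads, head_dim, max_len)
    padded_keys[:, :, :, -seq_len:] = past_keys
```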
@@ -243,13 +241,13 @@ def __init__(self, model_name: str, quantize=False):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
-        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
             load_in_8bit=quantize,
         ).eval()
+        tokenizer.pad_token_id = self.model.config.pad_token_id

         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
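The `__init__` hunk stops registering a brand-new `[PAD]` token (which can leave the tokenizer vocabulary out of sync with the model's embedding matrix unless the embeddings are resized) and instead points the tokenizer at the pad token id the model config already declares. A hedged sketch of the idea; the model name and the EOS fallback are illustrative assumptions, not part of the patch.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any causal LM with a configured pad or eos token works.
model_name = "bigscience/bloom-560m"

tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

if model.config.pad_token_id is not None:
    # Reuse the model's own pad token instead of adding a new [PAD] token.
    tokenizer.pad_token_id = model.config.pad_token_id
else:
    # Illustrative fallback: many decoder-only models pad with EOS.
    tokenizer.pad_token_id = model.config.eos_token_id
```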