-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Expand file tree
/
Copy pathtest_modeling_mllama.py
More file actions
470 lines (438 loc) · 15.4 KB
/
test_modeling_mllama.py
File metadata and controls
470 lines (438 loc) · 15.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
import re
import unittest
from copy import deepcopy
import pytest
import torch
from _torch.helpers import create_mock_cuda_graph_runner
from parameterized import parameterized
from test_modeling_llama import Scenario, reduce_llama_config
from transformers import MllamaConfig
from transformers import \
MllamaForConditionalGeneration as HFMllamaForConditionalGeneration
from utils.util import getSMVersion
import tensorrt_llm
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_mllama import \
MllamaForConditionalGeneration
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
# Hugging Face-style config dict for Llama 3.2 11B Vision
# (MllamaForConditionalGeneration), with separate 'text_config' and
# 'vision_config' sections. Tests treat this as a read-only template: they
# deepcopy() it, shrink the copy (e.g. via reduce_llama_config and by
# overriding layer counts), and then build both the HF reference model and the
# TRT-LLM model from it with MllamaConfig.from_dict(). Do not mutate it in
# place.
LLAMA_3_2_11B_VISION_CONFIG = {
'architectures': ['MllamaForConditionalGeneration'],
'image_token_index': 128256,
'model_type': 'mllama',
'text_config': {
'_name_or_path': '',
'add_cross_attention': False,
'architectures': None,
'bad_words_ids': None,
'begin_suppress_tokens': None,
'bos_token_id': 128000,
'chunk_size_feed_forward': 0,
'cross_attention_hidden_size': None,
'cross_attention_layers': [3, 8, 13, 18, 23, 28, 33, 38],
'decoder_start_token_id': None,
'diversity_penalty': 0.0,
'do_sample': False,
'dropout': 0,
'early_stopping': False,
'encoder_no_repeat_ngram_size': 0,
'eos_token_id': 128001,
'exponential_decay_length_penalty': None,
'finetuning_task': None,
'forced_bos_token_id': None,
'forced_eos_token_id': None,
'hidden_act': 'silu',
'hidden_size': 4096,
'id2label': {
'0': 'LABEL_0',
'1': 'LABEL_1'
},
'initializer_range': 0.02,
'intermediate_size': 14336,
'is_decoder': False,
'is_encoder_decoder': False,
'label2id': {
'LABEL_0': 0,
'LABEL_1': 1
},
'length_penalty': 1.0,
'max_length': 20,
'max_position_embeddings': 131072,
'min_length': 0,
'model_type': 'mllama_text_model',
'no_repeat_ngram_size': 0,
'num_attention_heads': 32,
'num_beam_groups': 1,
'num_beams': 1,
'num_hidden_layers': 40,
'num_key_value_heads': 8,
'num_return_sequences': 1,
'output_attentions': False,
'output_hidden_states': False,
'output_scores': False,
'pad_token_id': 128004,
'prefix': None,
'problem_type': None,
'pruned_heads': {},
'remove_invalid_values': False,
'repetition_penalty': 1.0,
'return_dict': True,
'return_dict_in_generate': False,
'rms_norm_eps': 1e-05,
'rope_scaling': {
'factor': 8.0,
'high_freq_factor': 4.0,
'low_freq_factor': 1.0,
'original_max_position_embeddings': 8192,
'rope_type': 'llama3'
},
'rope_theta': 500000.0,
'sep_token_id': None,
'suppress_tokens': None,
'task_specific_params': None,
'temperature': 1.0,
'tf_legacy_loss': False,
'tie_encoder_decoder': False,
'tie_word_embeddings': False,
'tokenizer_class': None,
'top_k': 50,
'top_p': 1.0,
'torch_dtype': 'bfloat16',
'torchscript': False,
'typical_p': 1.0,
'use_bfloat16': False,
'use_cache': True,
'vocab_size': 128256
},
'torch_dtype': 'bfloat16',
'transformers_version': '4.45.0.dev0',
'vision_config': {
'_name_or_path':
'',
'add_cross_attention':
False,
'architectures':
None,
'attention_heads':
16,
'bad_words_ids':
None,
'begin_suppress_tokens':
None,
'bos_token_id':
None,
'chunk_size_feed_forward':
0,
'cross_attention_hidden_size':
None,
'decoder_start_token_id':
None,
'diversity_penalty':
0.0,
'do_sample':
False,
'early_stopping':
False,
'encoder_no_repeat_ngram_size':
0,
'eos_token_id':
None,
'exponential_decay_length_penalty':
None,
'finetuning_task':
None,
'forced_bos_token_id':
None,
'forced_eos_token_id':
None,
'hidden_act':
'gelu',
'hidden_size':
1280,
'id2label': {
'0': 'LABEL_0',
'1': 'LABEL_1'
},
'image_size':
448,
'intermediate_layers_indices': [3, 7, 15, 23, 30],
'intermediate_size':
5120,
'is_decoder':
False,
'is_encoder_decoder':
False,
'label2id': {
'LABEL_0': 0,
'LABEL_1': 1
},
'length_penalty':
1.0,
'max_length':
20,
'max_num_tiles':
4,
'min_length':
0,
'model_type':
'mllama_vision_model',
'no_repeat_ngram_size':
0,
'norm_eps':
1e-05,
'num_beam_groups':
1,
'num_beams':
1,
'num_channels':
3,
'num_global_layers':
8,
'num_hidden_layers':
32,
'num_return_sequences':
1,
'output_attentions':
False,
'output_hidden_states':
False,
'output_scores':
False,
'pad_token_id':
None,
'patch_size':
14,
'prefix':
None,
'problem_type':
None,
'pruned_heads': {},
'remove_invalid_values':
False,
'repetition_penalty':
1.0,
'return_dict':
True,
'return_dict_in_generate':
False,
'sep_token_id':
None,
'supported_aspect_ratios': [[1, 1], [1, 2], [1, 3], [1, 4], [2, 1],
[2, 2], [3, 1], [4, 1]],
'suppress_tokens':
None,
'task_specific_params':
None,
'temperature':
1.0,
'tf_legacy_loss':
False,
'tie_encoder_decoder':
False,
'tie_word_embeddings':
True,
'tokenizer_class':
None,
'top_k':
50,
'top_p':
1.0,
'torch_dtype':
'bfloat16',
'torchscript':
False,
'typical_p':
1.0,
'use_bfloat16':
False,
'vision_output_dim':
7680
}
}
# (HF checkpoint prefix, TRT-LLM prefix) rules applied in order by
# convert_weights_names. Prefixes are matched literally; the previous regex
# form left '.' unescaped, so it also matched any character in that position.
_MLLAMA_WEIGHT_PREFIX_RULES = (
    ("model.language_model", "language_model.model"),
    ("model.vision_model", "vision_model"),
    ("model.multi_modal_projector", "multi_modal_projector"),
    ("lm_head", "language_model.lm_head"),
)


def convert_weights_names(weights: dict) -> dict:
    """Rename HF Mllama weight names to the TRT-LLM naming scheme.

    Since transformers >= 4.52.0 the default Mllama model layout nests the
    language model, vision model and projector under ``model.``. The weight
    names must be converted to match what the TRT-LLM model expects. Every
    rule in ``_MLLAMA_WEIGHT_PREFIX_RULES`` is tried in order, and a name
    starting with a rule's prefix gets that prefix replaced. Names matching
    no rule are kept unchanged. Values are passed through as-is; the input
    mapping is not modified.

    Args:
        weights: Mapping from weight name to value (e.g. a ``state_dict()``).

    Returns:
        A new dict holding the same values under the converted names.
    """
    converted_weights = {}
    for weight_name, weight_value in weights.items():
        new_name = weight_name
        for prefix, replacement in _MLLAMA_WEIGHT_PREFIX_RULES:
            if new_name.startswith(prefix):
                new_name = replacement + new_name[len(prefix):]
        converted_weights[new_name] = weight_value
    return converted_weights
class TestMLlama(unittest.TestCase):
    """Numerical parity tests between the TRT-LLM and HF Mllama models."""

    @parameterized.expand(
        [
            Scenario(backend="VANILLA"),
            Scenario(backend="FLASHINFER"),
            Scenario(backend="FLASHINFER", use_cuda_graph=True),
            Scenario(backend="TRTLLM"),
            Scenario(backend="TRTLLM", use_cuda_graph=True),
        ],
        lambda testcase_func, param_num, param:
        f"{testcase_func.__name__}[{param.args[0]}]",
    )
    @torch.no_grad()
    def test_mllama_allclose_to_hf_text_only(self, scenario: Scenario) -> None:
        """Compare output to HF on a text-only prompt.

        Builds a reduced-size model from LLAMA_3_2_11B_VISION_CONFIG. It loads
        the HF model's weights into the TRT-LLM model and checks that the
        TRT-LLM logits match HF's last-position logits for one context
        (prefill) step and one generation step. The generation step
        optionally runs through a mock CUDA graph runner.
        """
        if scenario.backend == "FLASHINFER":
            pytest.skip("https://nvbugspro.nvidia.com/bug/5458945")

        backend = scenario.backend
        metadata_cls = get_attention_backend(backend).Metadata

        torch.random.manual_seed(0)
        config_dict = deepcopy(LLAMA_3_2_11B_VISION_CONFIG)
        dtype = MllamaConfig.from_dict(config_dict['text_config']).torch_dtype
        dtype_bytes = dtype.itemsize

        # Memory upper bound used to size the text decoder: 11B parameters
        # (Mllama also has a vision encoder; 11B is used as an upper bound for
        # everything) * dtype size, plus ~1.3x extra for activations.
        # NOTE(review): the leading factor of 2 presumably covers holding both
        # the HF reference and the TRT-LLM model at the same time -- confirm.
        activation_factor = 1.3
        model_params = 11 * (10**9)
        mem_for_full_model = 2 * model_params * dtype_bytes * activation_factor
        reduce_llama_config(mem_for_full_model, config_dict['text_config'], 8)
        if config_dict['text_config']['num_hidden_layers'] <= 0:
            self.skipTest('Insufficient memory for a single Llama layer')

        # For text path only, downscale vision encoder to only 1 layer. This
        # must happen before MllamaConfig.from_dict(); the config object is
        # built from the dict's contents at that point, so later edits to
        # config_dict would not reach the models.
        config_dict['vision_config']['num_hidden_layers'] = 1
        mllama_config = MllamaConfig.from_dict(config_dict)

        device = torch.device('cuda')
        hf_mllama = HFMllamaForConditionalGeneration(mllama_config).to(
            dtype).to(device).eval()
        mllama = MllamaForConditionalGeneration(
            ModelConfig(pretrained_config=mllama_config,
                        attn_backend=backend)).to(dtype).to(device)
        mllama.load_weights(convert_weights_names(hf_mllama.state_dict()))

        # KV cache setup: a single request and a single block of
        # `tokens_per_block` tokens.
        num_blocks = 1
        tokens_per_block = 128
        head_dim = mllama.config.hidden_size // mllama.config.num_attention_heads
        num_layers = mllama.config.num_hidden_layers
        num_kv_heads = mllama.config.num_key_value_heads
        max_seq_len = num_blocks * tokens_per_block
        batch_size = 1

        if dtype == torch.half:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
        elif dtype == torch.bfloat16:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
        else:
            raise ValueError("Invalid dtype")

        mapping = Mapping(world_size=1, tp_size=1, rank=0)
        kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
                                        tokens_per_block)
        kv_cache_manager = KVCacheManager(
            kv_cache_config,
            tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
            num_layers=num_layers,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            tokens_per_block=tokens_per_block,
            max_seq_len=max_seq_len,
            max_batch_size=batch_size,
            mapping=mapping,
            dtype=kv_cache_dtype,
        )
        # Release the KV cache even when an assertion below fails, so a
        # failing case does not leak it into the following test cases.
        self.addCleanup(kv_cache_manager.shutdown)

        atol = 0.35 if getSMVersion() >= 121 else 0.3

        def assert_close_to_hf(logits, ref):
            """Compare TRT-LLM logits with HF's logits at the last position."""
            torch.testing.assert_close(logits,
                                       ref.logits[:, -1].float(),
                                       atol=atol,
                                       rtol=0.3)

        # context
        input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
                                 dtype=torch.int,
                                 device=device)
        request_ids = [1]
        prompt_lens = [input_ids.size(-1)]
        kv_cache_manager.add_dummy_requests(request_ids, [input_ids.size(-1)])

        def make_attn_metadata(num_tokens, num_contexts, num_cached_tokens):
            """Build attention metadata for one step of the single request."""
            return metadata_cls(
                seq_lens=torch.tensor([num_tokens], dtype=torch.int),
                num_contexts=num_contexts,
                kv_cache_params=KVCacheParams(
                    use_cache=True,
                    num_cached_tokens_per_seq=[num_cached_tokens]),
                max_num_requests=1,
                max_num_tokens=8192,
                kv_cache_manager=kv_cache_manager,
                request_ids=request_ids,
                prompt_lens=prompt_lens)

        attn_metadata = make_attn_metadata(num_tokens=input_ids.size(-1),
                                           num_contexts=1,
                                           num_cached_tokens=0)

        # Note: no CUDA graphs for prefill, the graph runner is built for
        # decoding only.
        position_ids = torch.arange(0, input_ids.size(-1)).unsqueeze(0).cuda()
        with torch.inference_mode():
            attn_metadata.prepare()
            logits = mllama.forward(input_ids=input_ids,
                                    position_ids=position_ids,
                                    attn_metadata=attn_metadata)
            ref = hf_mllama.forward(input_ids=input_ids.unsqueeze(0),
                                    position_ids=position_ids,
                                    use_cache=True)
        assert_close_to_hf(logits, ref)

        # gen: one new token on top of the cached context.
        gen_input_ids = torch.tensor([600], dtype=torch.int, device=device)
        attn_metadata = make_attn_metadata(num_tokens=gen_input_ids.size(-1),
                                           num_contexts=0,
                                           num_cached_tokens=input_ids.size(-1))
        gen_position_ids = torch.arange(
            input_ids.size(-1),
            input_ids.size(-1) + gen_input_ids.size(-1)).unsqueeze(0).cuda()

        graph_runner = create_mock_cuda_graph_runner(
            1) if scenario.use_cuda_graph else None
        if graph_runner is not None:
            # Registered after the KV cache cleanup, so (cleanups run LIFO)
            # the graphs are cleared before the KV cache is shut down.
            self.addCleanup(graph_runner.clear)

        def run_forward(input_ids, position_ids, attn_metadata):
            """Run one step eagerly, or capture and replay it as a CUDA graph."""
            attn_metadata.prepare()
            if not scenario.use_cuda_graph:
                return mllama.forward(input_ids=input_ids,
                                      position_ids=position_ids,
                                      attn_metadata=attn_metadata)
            inputs = {
                "input_ids": input_ids,
                "position_ids": position_ids,
                "attn_metadata": attn_metadata,
            }
            # NOTE(review): key layout is defined by the graph runner
            # (presumably batch size 1 first) -- confirm.
            key = (1, 0, False)
            graph_runner.capture(key, lambda inputs: mllama.forward(**inputs),
                                 inputs)
            for _ in range(2):
                # Run it twice. This helps us catch problems if buffers are
                # accidentally reallocated in prepare().
                attn_metadata.prepare()
                logits = graph_runner.replay(key, inputs)
            return logits

        if scenario.use_cuda_graph:
            attn_metadata = attn_metadata.create_cuda_graph_metadata(1)

        with torch.inference_mode():
            logits = run_forward(input_ids=gen_input_ids,
                                 position_ids=gen_position_ids,
                                 attn_metadata=attn_metadata)
            ref = hf_mllama.forward(input_ids=gen_input_ids.unsqueeze(0),
                                    position_ids=gen_position_ids,
                                    past_key_values=ref.past_key_values,
                                    use_cache=True)
        assert_close_to_hf(logits, ref)