Commit dc96e9e

committed
Test for bloom that fails with inference kernels.
1 parent ae7cd6a commit dc96e9e

File tree

1 file changed: +65 −36 lines changed


tests/test_generation.py

Lines changed: 65 additions & 36 deletions
@@ -2,6 +2,9 @@
 import torch
 import math
 
+from itertools import product
+
+import transformers
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
@@ -11,7 +14,7 @@
     set_seed,
 
 )
-import transformers
+
 
 
 def get_4bit_config():
@@ -26,15 +29,23 @@ def get_4bit_config():
     )
 
 
-def get_model(model_name_or_path='huggyllama/llama-7b', bnb_config=get_4bit_config()):
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name_or_path,
-        quantization_config=bnb_config,
-        max_memory={0:'48GB'},
-        device_map='auto'
-        ).eval()
+def get_model_and_tokenizer(config):
+    model_name_or_path, quant_type = config
+    bnb_config = get_4bit_config()
+    if quant_type == '16bit':
+        bnb_config.load_in_4bit = False
+    else:
+        bnb_config.bnb_4bit_quant_type= quant_type
+    model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
+        quantization_config=bnb_config,
+        max_memory={0:'48GB'},
+        device_map='auto',
+        torch_dtype=torch.bfloat16
+        ).eval()
 
-    return model
+    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
+
+    return model, tokenizer
 
 def get_prompt_for_generation_eval(text, add_roles=True):
     description = (
@@ -53,48 +64,66 @@ def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_f
     outputs = model.generate(inputs=inputs['input_ids'], generation_config=generation_config)
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-name_or_path = 'huggyllama/llama-7b'
-#name_or_path = 'AI-Sweden/gpt-sw3-126m'
-
-@pytest.fixture(scope='session')
-def model():
-    bnb_config = get_4bit_config()
-    bnb_config.bnb_4bit_compute_dtype=torch.float32
-    bnb_config.load_in_4bit=True
-    model = get_model(name_or_path)
-    print('')
-    return model
-
-@pytest.fixture(scope='session')
-def tokenizer():
-    tokenizer = transformers.AutoTokenizer.from_pretrained(name_or_path)
-    return tokenizer
-
+models = ['huggyllama/llama-7b', 'bigscience/bloom-1b7']
+dtypes = ['nf4', 'fp4', '16bit']
+load_in_4bit = [True, False]
+values = list(product(models, dtypes))
+strfunc = lambda lst: [str(x) for x in lst]
+ids = ['_'.join(strfunc(x)) for x in values]
+@pytest.fixture(scope='session', params=values, ids=ids)
+def model_and_tokenizer(request):
+    model, tokenizer = get_model_and_tokenizer(request.param)
+    yield model, tokenizer
+    del model
+
+@pytest.mark.parametrize("inference_kernel", [True, False], ids=['inference_kernel_True', 'inference_kernel_False'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
-def test_pi(model, tokenizer, dtype):
+def test_pi(model_and_tokenizer, dtype, inference_kernel):
+
+    model, tokenizer = model_and_tokenizer
 
     generation_config = transformers.GenerationConfig(
-        max_new_tokens=128,
+        max_new_tokens=20,
         do_sample=True,
         top_p=0.9,
         temperature=0.7,
     )
-    generation_config.max_new_tokens = 50
+    generation_config.max_new_tokens = 20
 
 
    #text = 'Please write down the first 50 digits of pi.'
    #text = get_prompt_for_generation_eval(text)
    #text += ' Sure, here the first 50 digits of pi: 3.14159'
+    n_cases = 3
     text = '3.14159'
-    model.config.quantization_config.bnb_4bit_compute_dtype = dtype
+    if hasattr(model.config, 'quantization_config'):
+        model.config.quantization_config.bnb_4bit_compute_dtype = dtype
 
+    if not inference_kernel:
+        text = [text]*n_cases
     inputs = tokenizer(text, return_tensors="pt").to('cuda:0')
-    outputs = model.generate(inputs=inputs['input_ids'], generation_config=generation_config)
-    textout = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    print('')
-    print(textout)
-    print(math.pi)
+    x = inputs['input_ids']
+    failure_count = 0
+    outputs = []
+    if inference_kernel:
+        for i in range(n_cases):
+            output = model.generate(x, generation_config=generation_config)
+            textout = tokenizer.decode(output[0], skip_special_tokens=True)
+            outputs.append(textout)
+    else:
+        outputs = model.generate(x, generation_config=generation_config)
+        outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
+
+
+    assert len(outputs) == n_cases
+    for i in range(n_cases):
+        if not outputs[i][:len(str(math.pi))] == str(math.pi):
+            failure_count += 1
+    if failure_count > 1:
+        print(math.pi)
        for out in outputs:
            print(out)
        raise ValueError(f'Failure count: {failure_count}/{n_cases}')
 
-    assert textout[:len(str(math.pi))] == str(math.pi)
 
 
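For context, a rough standalone sketch of the scenario the new test_pi covers for bloom follows: generating from '3.14159' one sample at a time (the inference_kernel=True case, which exercises the batch-size-1 4-bit kernels) versus as a batch of three identical prompts, then counting how many continuations fail to start with the digits of pi. Only the model name, prompt, check, and decoding settings are taken from the diff above; the BitsAndBytesConfig values, the 'cuda:0' placement, and the variable names are assumptions for illustration, not part of the commit, and running it needs a CUDA GPU with bitsandbytes installed.

# Standalone sketch (not part of this commit) mirroring what test_pi probes for bloom.
import math
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = 'bigscience/bloom-1b7'
# Assumed 4-bit settings; the test varies quant type and compute dtype via parametrization.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map='auto',
    torch_dtype=torch.bfloat16,
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)

gen_config = transformers.GenerationConfig(
    max_new_tokens=20, do_sample=True, top_p=0.9, temperature=0.7
)
prompt = '3.14159'

# Batch size 1: three separate generate() calls, the path the test labels inference_kernel=True.
single = tokenizer(prompt, return_tensors='pt').to('cuda:0')['input_ids']
kernel_outputs = [
    tokenizer.decode(model.generate(single, generation_config=gen_config)[0],
                     skip_special_tokens=True)
    for _ in range(3)
]

# Batch size 3: the same prompt repeated, which per the test's parametrization bypasses that path.
batch = tokenizer([prompt] * 3, return_tensors='pt').to('cuda:0')['input_ids']
batched_outputs = [
    tokenizer.decode(out, skip_special_tokens=True)
    for out in model.generate(batch, generation_config=gen_config)
]

# The test treats a configuration as failing when more than one of the three
# continuations does not start with the digits of pi.
for label, outs in [('inference kernels', kernel_outputs), ('batched', batched_outputs)]:
    failures = sum(not out.startswith(str(math.pi)) for out in outs)
    print(f'{label}: {failures}/3 continuations diverge from {math.pi}')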