import torch
import math

+from itertools import product
+
+import transformers
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    set_seed,

)
-import transformers
+



def get_4bit_config():
@@ -26,15 +29,23 @@ def get_4bit_config():
    )


-def get_model(model_name_or_path='huggyllama/llama-7b', bnb_config=get_4bit_config()):
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name_or_path,
-        quantization_config=bnb_config,
-        max_memory={0:'48GB'},
-        device_map='auto'
-    ).eval()
+def get_model_and_tokenizer(config):
+    model_name_or_path, quant_type = config
+    bnb_config = get_4bit_config()
+    if quant_type == '16bit':
+        bnb_config.load_in_4bit = False
+    else:
+        bnb_config.bnb_4bit_quant_type = quant_type
+    model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
+        quantization_config=bnb_config,
+        max_memory={0:'48GB'},
+        device_map='auto',
+        torch_dtype=torch.bfloat16
+    ).eval()

-    return model
+    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
+
+    return model, tokenizer

def get_prompt_for_generation_eval(text, add_roles=True):
    description = (
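Note that the new helper takes a (model_name_or_path, quant_type) tuple instead of a ready-made quantization config, so it can be driven directly by the fixture parameters introduced in the next hunk. Outside of pytest it can also be exercised on its own; a minimal sketch, assuming the checkpoint can be downloaded and a CUDA device with enough memory is available:

    model, tokenizer = get_model_and_tokenizer(('huggyllama/llama-7b', 'nf4'))
    config = transformers.GenerationConfig(max_new_tokens=20, do_sample=True, top_p=0.9, temperature=0.7)
    print(generate(model, tokenizer, '3.14159', config))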
@@ -53,48 +64,66 @@ def generate(model, tokenizer, text, generation_config, prompt_func=get_prompt_f
    outputs = model.generate(inputs=inputs['input_ids'], generation_config=generation_config)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

-name_or_path = 'huggyllama/llama-7b'
-#name_or_path = 'AI-Sweden/gpt-sw3-126m'
-
-@pytest.fixture(scope='session')
-def model():
-    bnb_config = get_4bit_config()
-    bnb_config.bnb_4bit_compute_dtype = torch.float32
-    bnb_config.load_in_4bit = True
-    model = get_model(name_or_path)
-    print('')
-    return model
-
-@pytest.fixture(scope='session')
-def tokenizer():
-    tokenizer = transformers.AutoTokenizer.from_pretrained(name_or_path)
-    return tokenizer
-
+models = ['huggyllama/llama-7b', 'bigscience/bloom-1b7']
+dtypes = ['nf4', 'fp4', '16bit']
+load_in_4bit = [True, False]
+values = list(product(models, dtypes))
+strfunc = lambda lst: [str(x) for x in lst]
+ids = ['_'.join(strfunc(x)) for x in values]
+@pytest.fixture(scope='session', params=values, ids=ids)
+def model_and_tokenizer(request):
+    model, tokenizer = get_model_and_tokenizer(request.param)
+    yield model, tokenizer
+    del model
+
+@pytest.mark.parametrize("inference_kernel", [True, False], ids=['inference_kernel_True', 'inference_kernel_False'])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=['fp16', 'bf16', 'fp32'])
-def test_pi(model, tokenizer, dtype):
+def test_pi(model_and_tokenizer, dtype, inference_kernel):
+
+    model, tokenizer = model_and_tokenizer

    generation_config = transformers.GenerationConfig(
-        max_new_tokens=128,
+        max_new_tokens=20,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
    )
-    generation_config.max_new_tokens = 50
+    generation_config.max_new_tokens = 20


    #text = 'Please write down the first 50 digits of pi.'
    #text = get_prompt_for_generation_eval(text)
    #text += ' Sure, here the first 50 digits of pi: 3.14159'
+    n_cases = 3
    text = '3.14159'
-    model.config.quantization_config.bnb_4bit_compute_dtype = dtype
+    if hasattr(model.config, 'quantization_config'):
+        model.config.quantization_config.bnb_4bit_compute_dtype = dtype

+    if not inference_kernel:
+        text = [text]*n_cases
    inputs = tokenizer(text, return_tensors="pt").to('cuda:0')
-    outputs = model.generate(inputs=inputs['input_ids'], generation_config=generation_config)
-    textout = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    print('')
-    print(textout)
-    print(math.pi)
+    x = inputs['input_ids']
+    failure_count = 0
+    outputs = []
+    if inference_kernel:
+        for i in range(n_cases):
+            output = model.generate(x, generation_config=generation_config)
+            textout = tokenizer.decode(output[0], skip_special_tokens=True)
+            outputs.append(textout)
+    else:
+        outputs = model.generate(x, generation_config=generation_config)
+        outputs = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
+
+
+    assert len(outputs) == n_cases
+    for i in range(n_cases):
+        if not outputs[i][:len(str(math.pi))] == str(math.pi):
+            failure_count += 1
+    if failure_count > 1:
+        print(math.pi)
+        for out in outputs:
+            print(out)
+        raise ValueError(f'Failure count: {failure_count}/{n_cases}')

-    assert textout[:len(str(math.pi))] == str(math.pi)


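Taken together, the session-scoped fixture builds one (model, tokenizer) pair per entry of product(models, dtypes), and test_pi is additionally crossed with its dtype and inference_kernel parameters, so each generated test id combines a fixture id with the two mark ids and can be selected with pytest's -k filter (e.g. -k "nf4 and bf16 and inference_kernel_True"). A quick check of the fixture ids the comprehension above produces:

    >>> list(product(models, dtypes))[0]
    ('huggyllama/llama-7b', 'nf4')
    >>> ids[:3]
    ['huggyllama/llama-7b_nf4', 'huggyllama/llama-7b_fp4', 'huggyllama/llama-7b_16bit']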