@@ -15,10 +15,7 @@ def phi3_model(selected_model, selected_model_name):
1515
@pytest.fixture(scope="module")
def llama3_model(selected_model, selected_model_name):
    """Yield the session-scoped Llama3 model, or skip tests that need it."""
    # Only the CPU 8B Llama3 configuration is usable here; any other model
    # name, or a model that failed to load, means dependent tests can't run.
    usable = selected_model is not None and selected_model_name in ["transformers_llama3cpu_8b"]
    if not usable:
        # pytest.skip raises, so control never falls through to the return.
        pytest.skip("Requires Llama3 model (needs HF_TOKEN to be set)")
    return selected_model
@@ -27,7 +24,7 @@ def llama3_model(selected_model, selected_model_name):
def test_gpt2():
    """Smoke test: gpt2 must append generated text after the prompt."""
    prompt = "this is a test"
    model = get_model("transformers:gpt2")
    lm = model + prompt + gen("test", max_tokens=10)
    # Generation should have added at least one character beyond the prompt.
    assert len(str(lm)) > len(prompt)
3229
3330
@@ -42,9 +39,7 @@ def test_recursion_error():
4239 { gen ('verse' , max_tokens = 2 )}
4340 """
4441 )
45- assert len (str (lm )) > len (
46- "Tweak this proverb to apply to model instructions instead.\n \n "
47- )
42+ assert len (str (lm )) > len ("Tweak this proverb to apply to model instructions instead.\n \n " )
4843
4944
5045TRANSFORMER_MODELS = {
@@ -81,6 +76,7 @@ def test_transformer_smoke_select(model_name, model_kwargs):
8176
8277# Phi-3 specific tests
8378
79+
8480@pytest .mark .skip ("Don't overload the build machines" )
8581def test_phi3_transformers_orig ():
8682 import torch
@@ -116,11 +112,10 @@ def test_phi3_transformers_orig():
def test_phi3_chat_basic(phi3_model: models.Model):
    """Phi-3 chat smoke test: the model should continue a counting sequence."""
    chat = phi3_model
    # Phi-3 has no separate system role here, so the instruction goes in the
    # user turn; the assistant turn is primed with the start of the sequence.
    with user():
        chat += "You are a counting bot. Just keep counting numbers. "
    with assistant():
        chat += "1,2,3,4," + gen(name="five", max_tokens=20)
    # The captured continuation of "1,2,3,4," should include the digit 5.
    assert "5" in chat["five"]
126121
@@ -143,7 +138,7 @@ def test_phi3_newline_chat(phi3_model: models.Model):
143138 with assistant ():
144139 lm += "\n " + gen (name = "five" , max_tokens = 1 )
145140 lm += "\n " + gen (name = "six" , max_tokens = 1 )
146-
141+
147142 # This test would raise an exception earlier if we didn't fix the tokenizer.
148143 assert True
149144
@@ -155,7 +150,7 @@ def test_phi3_unstable_tokenization(phi3_model: models.Model):
155150 with user ():
156151 lm += "1,2,3,4,"
157152 with assistant ():
158- lm += "\n " # comment and uncomment this line to get the error
153+ lm += "\n " # comment and uncomment this line to get the error
159154 lm += gen (name = "five" , max_tokens = 1 )
160155 lm += "," + gen (name = "six" , max_tokens = 1 )
161156
@@ -168,4 +163,4 @@ def test_phi3_basic_completion_badtokens(phi3_model: models.Model):
168163 lm += f"""<|use\n \n You are a counting bot. Just keep counting numbers.<|end|><|assistant|>1,2,3,4,"""
169164 lm += gen ("five" , max_tokens = 10 )
170165
171- assert len (lm ["five" ]) > 0
166+ assert len (lm ["five" ]) > 0
0 commit comments