@@ -13,28 +13,28 @@ def test_infill_without_input_extra():
1313 global server
1414 server .start ()
1515 res = server .make_request ("POST" , "/infill" , data = {
16- "prompt " : "Complete this " ,
17- "input_prefix " : "#include <cstdio> \n #include \" llama.h \" \n \n int main() { \n int n_threads = llama_" ,
16+ "input_prefix " : "#include <cstdio> \n #include \" llama.h \" \n \n int main() { \n " ,
17+ "prompt " : " int n_threads = llama_" ,
1818 "input_suffix" : "}\n " ,
1919 })
2020 assert res .status_code == 200
21- assert match_regex ("(One|day|she|saw|big|scary|bird )+" , res .body ["content" ])
21+ assert match_regex ("(Ann|small|shiny )+" , res .body ["content" ])
2222
2323
# test_infill_with_input_extra: starts the shared server and POSTs to /infill
# with an "input_extra" entry (a llama.h snippet declaring llama_n_threads())
# in addition to the prefix / prompt / suffix fields, then checks the reply.
# NOTE(review): this span is a diff rendering -- lines marked "-" after the
# line number are the removed version, lines marked "+" are the added one.
2424def test_infill_with_input_extra ():
2525 global server
2626 server .start ()
2727 res = server .make_request ("POST" , "/infill" , data = {
# Removed: the fixed "Complete this" prompt is no longer sent.
28- "prompt" : "Complete this" ,
2928 "input_extra" : [{
3029 "filename" : "llama.h" ,
3130 "text" : "LLAMA_API int32_t llama_n_threads();\n "
3231 }],
# Changed: the unfinished statement "int n_threads = llama_" moves from the
# end of "input_prefix" into "prompt"; "input_prefix" now ends after the
# opening of main().
33- "input_prefix" : "#include <cstdio>\n #include \" llama.h\" \n \n int main() {\n int n_threads = llama_" ,
32+ "input_prefix" : "#include <cstdio>\n #include \" llama.h\" \n \n int main() {\n " ,
33+ "prompt" : " int n_threads = llama_" ,
3434 "input_suffix" : "}\n " ,
3535 })
# The request must succeed with HTTP 200.
3636 assert res .status_code == 200
# The expected word list changes together with the request fields above.
# Presumably these are words the default test model emits for this input and
# match_regex checks the pattern against the generated text -- TODO confirm.
37- assert match_regex ("(cuts|Jimmy|mom|came|into|the|room )+" , res .body ["content" ])
37+ assert match_regex ("(Dad|excited|park )+" , res .body ["content" ])
3838
3939
4040@pytest .mark .parametrize ("input_extra" , [
@@ -48,10 +48,30 @@ def test_invalid_input_extra_req(input_extra):
4848 global server
4949 server .start ()
5050 res = server .make_request ("POST" , "/infill" , data = {
51- "prompt" : "Complete this" ,
5251 "input_extra" : [input_extra ],
53- "input_prefix" : "#include <cstdio>\n #include \" llama.h\" \n \n int main() {\n int n_threads = llama_" ,
52+ "input_prefix" : "#include <cstdio>\n #include \" llama.h\" \n \n int main() {\n " ,
53+ "prompt" : " int n_threads = llama_" ,
5454 "input_suffix" : "}\n " ,
5555 })
5656 assert res .status_code == 400
5757 assert "error" in res .body
58+
59+
# test_with_qwen_model: slow test, skipped unless is_slow_test_allowed() returns
# true. It clears server.model_file and selects a Qwen2.5-Coder GGUF through
# model_hf_repo / model_hf_file (presumably a Hugging Face repo and file name --
# TODO confirm), then sends an /infill request and checks the completion.
# NOTE(review): the string literals below appear to contain stray spaces from
# text extraction (e.g. "\\ n" inside the expected printf line); verify them
# against the real file before relying on the exact bytes.
60+ @pytest .mark .skipif (not is_slow_test_allowed (), reason = "skipping slow test" )
61+ def test_with_qwen_model ():
62+ global server
63+ server .model_file = None
64+ server .model_hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-IQ3_XXS-GGUF"
65+ server .model_hf_file = "qwen2.5-coder-1.5b-iq3_xxs-imat.gguf"
# Start-up is given 600 seconds; presumably this allows time to obtain the
# model -- TODO confirm.
66+ server .start (timeout_seconds = 600 )
67+ res = server .make_request ("POST" , "/infill" , data = {
# Extra context: a llama.h snippet that declares llama_n_threads().
68+ "input_extra" : [{
69+ "filename" : "llama.h" ,
70+ "text" : "LLAMA_API int32_t llama_n_threads();\n "
71+ }],
# Text before the gap, the unfinished statement to continue, and the text
# after the gap.
72+ "input_prefix" : "#include <cstdio>\n #include \" llama.h\" \n \n int main() {\n " ,
73+ "prompt" : " int n_threads = llama_" ,
74+ "input_suffix" : "}\n " ,
75+ })
76+ assert res .status_code == 200
# Exact-match check: the completion must equal this string, so any change in
# the model's output fails the test.
77+ assert res .body ["content" ] == "n_threads();\n printf(\" Number of threads: %d\\ n\" , n_threads);\n return 0;\n "
0 commit comments