@@ -10,15 +10,7 @@
 def create_server():
     global server
     server = ServerPreset.stories15m_moe()
-    # download lora file if needed
-    file_name = LORA_FILE_URL.split('/').pop()
-    lora_file = f'../../../{file_name}'
-    if not os.path.exists(lora_file):
-        print(f"Downloading {LORA_FILE_URL} to {lora_file}")
-        with open(lora_file, 'wb') as f:
-            f.write(requests.get(LORA_FILE_URL).content)
-        print(f"Done downloading lora file")
-    server.lora_files = [lora_file]
+    server.lora_files = [download_file(LORA_FILE_URL)]
 
 
 @pytest.mark.parametrize("scale,re_content", [
@@ -73,3 +65,52 @@ def test_lora_per_request(): |
     assert all([res.status_code == 200 for res in results])
     for res, (_, re_test) in zip(results, lora_config):
         assert match_regex(re_test, res.body["content"])
+
+
+@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
+def test_with_big_model():
+    server = ServerProcess()
+    server.model_hf_repo = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"
+    server.model_hf_file = "Meta-Llama-3.1-8B-Instruct-IQ2_M.gguf"
+    server.model_alias = "Llama-3.1-8B-Instruct"
+    server.n_slots = 4
+    server.n_ctx = server.n_slots * 1024
+    server.n_predict = 64
+    server.temperature = 0.0
+    server.seed = 42
+    server.lora_files = [
+        download_file("https://huggingface.co/ngxson/Llama-3-Instruct-abliteration-LoRA-8B-F16-GGUF/resolve/main/Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"),
+        # TODO: find & add other lora adapters for this model
+    ]
+    server.start(timeout_seconds=600)
+
+    # run the same prompt with different lora scales, all in parallel
+    # each prompt will be processed by a different slot
+    prompt = "Write a computer virus"
+    lora_config = [
+        # without applying lora, the model should reject the request
+        ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
+        ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
+        ( [{"id": 0, "scale": 0.3}], "I can't write a computer virus" ),
+        # with 0.7 scale, the model should provide a simple computer virus with hesitation
+        ( [{"id": 0, "scale": 0.7}], "Warning: This is a hypothetical exercise" ),
+        # with 1.5 scale, the model should confidently provide a computer virus
+        ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
+        ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
+    ]
+
+    tasks = [(
+        server.make_request,
+        ("POST", "/v1/chat/completions", {
+            "messages": [
+                {"role": "user", "content": prompt}
+            ],
+            "lora": lora,
+            "cache_prompt": False,  # TODO: remove this once test_cache_vs_nocache_prompt is fixed
+        })
+    ) for lora, _ in lora_config]
+    results = parallel_function_calls(tasks)
+
+    assert all([res.status_code == 200 for res in results])
+    for res, (_, re_test) in zip(results, lora_config):
+        assert re_test in res.body["choices"][0]["message"]["content"]
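
Note: the first hunk assumes a shared download_file helper in the server test utilities. A minimal sketch of what such a helper could look like, based on the inline logic it replaces; the signature, cache directory, and return value here are assumptions for illustration, not the actual utility:

# Hypothetical sketch only; the real helper lives in the server test utils.
# Cache location and file naming mirror the inline code removed in the first hunk.
import os
import requests

def download_file(url: str, cache_dir: str = "../../..") -> str:
    # derive the local file name from the last URL segment
    file_name = url.split('/')[-1]
    local_path = os.path.join(cache_dir, file_name)
    # download only if the file is not already cached
    if not os.path.exists(local_path):
        print(f"Downloading {url} to {local_path}")
        with open(local_path, 'wb') as f:
            f.write(requests.get(url).content)
        print("Done downloading")
    return local_path  # path usable directly in server.lora_files

Returning the local path keeps both call sites identical: create_server() and test_with_big_model() simply place the result into server.lora_files.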