55
66import subprocess
77import os
8+ import re
9+ import json
810import sys
911import threading
1012import requests
1921 Sequence ,
2022 Set ,
2123)
24+ from re import RegexFlag
2225
2326
2427class ServerResponse :
@@ -34,6 +37,9 @@ class ServerProcess:
3437 server_host : str = "127.0.0.1"
3538 model_hf_repo : str = "ggml-org/models"
3639 model_hf_file : str = "tinyllamas/stories260K.gguf"
40+ model_alias : str = "tinyllama-2"
41+ temperature : float = 0.8
42+ seed : int = 42
3743
3844 # custom options
3945 model_alias : str | None = None
@@ -48,7 +54,6 @@ class ServerProcess:
4854 n_ga_w : int | None = None
4955 n_predict : int | None = None
5056 n_prompts : int | None = 0
51- n_server_predict : int | None = None
5257 slot_save_path : str | None = None
5358 id_slot : int | None = None
5459 cache_prompt : bool | None = None
@@ -58,12 +63,9 @@ class ServerProcess:
5863 server_embeddings : bool | None = False
5964 server_reranking : bool | None = False
6065 server_metrics : bool | None = False
61- seed : int | None = None
6266 draft : int | None = None
63- server_seed : int | None = None
6467 user_api_key : str | None = None
6568 response_format : str | None = None
66- temperature : float | None = None
6769 lora_file : str | None = None
6870 disable_ctx_shift : int | None = False
6971
@@ -86,6 +88,10 @@ def start(self, timeout_seconds: int = 10) -> None:
8688 self .server_host ,
8789 "--port" ,
8890 self .server_port ,
91+ "--temp" ,
92+ self .temperature ,
93+ "--seed" ,
94+ self .seed ,
8995 ]
9096 if self .model_file :
9197 server_args .extend (["--model" , self .model_file ])
@@ -119,8 +125,8 @@ def start(self, timeout_seconds: int = 10) -> None:
119125 server_args .extend (["--ctx-size" , self .n_ctx ])
120126 if self .n_slots :
121127 server_args .extend (["--parallel" , self .n_slots ])
122- if self .n_server_predict :
123- server_args .extend (["--n-predict" , self .n_server_predict ])
128+ if self .n_predict :
129+ server_args .extend (["--n-predict" , self .n_predict ])
124130 if self .slot_save_path :
125131 server_args .extend (["--slot-save-path" , self .slot_save_path ])
126132 if self .server_api_key :
@@ -216,12 +222,52 @@ def make_request(
216222 result .headers = dict (response .headers )
217223 result .status_code = response .status_code
218224 result .body = response .json ()
225+ print ("Response from server" , result .body )
219226 return result
227+
228+ def make_stream_request (
229+ self ,
230+ method : str ,
231+ path : str ,
232+ data : dict | None = None ,
233+ headers : dict | None = None ,
234+ ) -> Iterator [dict ]:
235+ url = f"http://{ self .server_host } :{ self .server_port } { path } "
236+ headers = {}
237+ if self .user_api_key :
238+ headers ["Authorization" ] = f"Bearer { self .user_api_key } "
239+ if method == "POST" :
240+ response = requests .post (url , headers = headers , json = data , stream = True )
241+ else :
242+ raise ValueError (f"Unimplemented method: { method } " )
243+ for line_bytes in response .iter_lines ():
244+ line = line_bytes .decode ("utf-8" )
245+ if '[DONE]' in line :
246+ break
247+ elif line .startswith ('data: ' ):
248+ data = json .loads (line [6 :])
249+ print ("Partial response from server" , data )
250+ yield data
220251
221252
222253server_instances : Set [ServerProcess ] = set ()
223254
224255
256+ class ServerPreset :
257+ @staticmethod
258+ def tinyllamas () -> ServerProcess :
259+ server = ServerProcess ()
260+ server .model_hf_repo = "ggml-org/models"
261+ server .model_hf_file = "tinyllamas/stories260K.gguf"
262+ server .model_alias = "tinyllama-2"
263+ server .n_ctx = 256
264+ server .n_batch = 32
265+ server .n_slots = 2
266+ server .n_predict = 64
267+ server .seed = 42
268+ return server
269+
270+
225271def multiple_post_requests (
226272 server : ServerProcess , path : str , data : Sequence [dict ], headers : dict | None = None
227273) -> Sequence [ServerResponse ]:
@@ -248,3 +294,12 @@ def thread_target(data_chunk):
248294 thread .join ()
249295
250296 return results
297+
298+
299+ def match_regex (regex : str , text : str ) -> bool :
300+ return (
301+ re .compile (
302+ regex , flags = RegexFlag .IGNORECASE | RegexFlag .MULTILINE | RegexFlag .DOTALL
303+ ).search (text )
304+ is not None
305+ )
0 commit comments