@@ -77,6 +77,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
7777 context .response_format = None
7878 context .temperature = None
7979 context .lora_file = None
80+ context .disable_ctx_shift = False
8081
8182 context .tasks_result = []
8283 context .concurrent_tasks = []
@@ -148,7 +149,7 @@ def step_n_slots(context, n_slots: int):
148149
@step('{n_predict:d} server max tokens to predict')
def step_server_n_predict(context, n_predict: int):
    """Record the server-side limit on predicted tokens for this scenario.

    Args:
        context: behave scenario context; ``n_server_predict`` is set on it.
        n_predict: value parsed from the step text.

    A non-positive ``n_predict`` is stored as ``None`` rather than as a number
    (presumably meaning "no explicit limit" — confirm against the code that
    builds the server arguments).
    """
    if n_predict > 0:
        context.n_server_predict = n_predict
    else:
        context.n_server_predict = None
152153
153154
154155@step ('{slot_save_path} as slot save path' )
@@ -180,6 +181,9 @@ def step_server_embeddings(context):
180181def step_server_metrics (context ):
181182 context .server_metrics = True
182183
@step('disable context shifting')
def step_server_disable_ctx_shift(context):
    """Mark the scenario so the server is launched with context shifting off.

    Sets ``context.disable_ctx_shift``; the server start-up code checks this
    flag and appends ``--no-context-shift`` to the server arguments.

    NOTE: this function was previously also named ``step_server_metrics``,
    which redefined (and shadowed) the real metrics step at module level.
    behave matches steps by the decorator text, so the rename does not change
    which step runs.
    """
    context.disable_ctx_shift = True
183187
184188@step ("the server is starting" )
185189def step_start_server (context ):
@@ -257,7 +261,7 @@ async def step_all_slots_status(context, expected_slot_status_string: Literal['i
257261@step ('a completion request with {api_error} api error' )
258262@async_run_until_complete
259263async def step_request_completion (context , api_error : Literal ['raised' ] | str ):
260- expect_api_error = api_error == 'raised'
264+ expect_api_error = api_error == 'raised' or api_error != 'no'
261265 seeds = await completions_seed (context , num_seeds = 1 )
262266 completion = await request_completion (context .prompts .pop (),
263267 seeds [0 ] if seeds is not None else seeds ,
@@ -272,8 +276,11 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
272276 context .tasks_result .append (completion )
273277 if context .debug :
274278 print (f"Completion response: { completion } " )
275- if expect_api_error :
279+ if api_error == 'raised' :
276280 assert completion == 401 , f"completion must be an 401 status code: { completion } "
281+ elif api_error .isdigit ():
282+ api_error_code = int (api_error )
283+ assert completion == api_error_code , f"completion must be an { api_error_code } status code: { completion } "
277284
278285
279286@step ('{predicted_n:d} tokens are predicted matching {re_content}' )
@@ -645,6 +652,9 @@ def step_assert_embeddings(context):
645652 for embedding in context .embeddings :
646653 assert_embeddings (embedding )
647654
@step('embeddings request with {api_error_code:d} api error')
def step_assert_embeddings_api_error(context, api_error_code: int):
    """Assert that the embeddings request produced the given HTTP status code.

    Args:
        context: behave scenario context; ``context.embeddings`` is compared
            directly against ``api_error_code`` (presumably it holds the
            integer status returned by the embeddings request helper on a
            non-200 reply — verify against the step that stores it).
        api_error_code: expected status code parsed from the step text.

    NOTE: this function was previously also named ``step_assert_embeddings``,
    redefining the step declared just above it. behave dispatches on the
    decorator text, so the distinct name only removes the shadowing.
    """
    assert context.embeddings == api_error_code, f"embeddings request must return code {api_error_code}, but got {context.embeddings}"
648658
649659@step ('an OAI compatible embeddings computation request for' )
650660@async_run_until_complete
@@ -1089,15 +1099,17 @@ async def oai_chat_completions(user_prompt,
10891099 return completion_response
10901100
10911101
async def request_embedding(content, seed, base_url=None) -> list[list[float]] | int:
    """POST ``content`` to ``{base_url}/embedding`` and return the outcome.

    Args:
        content: value sent as the ``"content"`` field of the JSON body.
        seed: accepted for signature parity with callers; not used here.
        base_url: prefix of the endpoint URL.

    Returns:
        On HTTP 200, a one-element list wrapping the ``embedding`` field of
        the JSON reply. On any other status, the integer status code itself,
        so that error scenarios can assert on it.
    """
    payload = {
        "content": content,
    }
    async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session:
        async with session.post(f'{base_url}/embedding', json=payload) as response:
            # Non-success replies are reported to the caller as the bare status.
            if response.status != 200:
                return response.status
            reply = await response.json()
            return [reply['embedding']]
11011113
11021114
11031115async def request_oai_embeddings (input , seed ,
@@ -1372,6 +1384,8 @@ def start_server_background(context):
13721384 server_args .append ('--verbose' )
13731385 if context .lora_file :
13741386 server_args .extend (['--lora' , context .lora_file ])
1387+ if context .disable_ctx_shift :
1388+ server_args .extend (['--no-context-shift' ])
13751389
13761390 args = [str (arg ) for arg in [context .server_path , * server_args ]]
13771391 print (f"bench: starting server with: { ' ' .join (args )} " )
0 commit comments