@@ -440,7 +440,7 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     "mistralai/Mistral-Small-3.1-24B-Instruct-2503": {
         "llm_models_subdir": "Mistral-Small-3.1-24B-Instruct-2503",
         "model_factory": "Mistral3VLM",
-        "compile_backend": "torch-simple",
+        # "compile_backend": "torch-simple",
         "model_kwargs": {
             "text_config": {"num_hidden_layers": 2},
             "vision_config": {"num_hidden_layers": 2},
@@ -473,10 +473,8 @@ def get_small_model_config(model_hub_id: str, **llm_args_kwargs) -> Dict[str, Any]:

     # add some defaults to llm_args
     llm_args["skip_loading_weights"] = True  # No weight loading to speed up things
-    llm_args["free_mem_ratio"] = 0.00  # we don't need the cache and it may cause OOM issues
     llm_args["attn_page_size"] = 4  # Make sure paging is activated despite small max_tokens
     llm_args["max_batch_size"] = 2  # Minimum batching to speed up things
-
     # update with custom llm_args kwargs
     llm_args.update(llm_args_kwargs)

@@ -494,10 +492,16 @@ def get_small_model_config(model_hub_id: str, **llm_args_kwargs) -> Dict[str, Any]:


 def get_small_model_config_pytest_param(
-    model_hub_id: str, pytest_param_kwargs=None, **llm_args_kwargs
+    model_hub_id: str,
+    attn_backend: str,
+    compile_backend: str,
+    pytest_param_kwargs=None,
+    **llm_args_kwargs,
 ):
     return pytest.param(
         get_small_model_config(model_hub_id, **llm_args_kwargs),
+        attn_backend,
+        compile_backend,
         id=model_hub_id,
         **(pytest_param_kwargs or {}),
     )