1- """A backend that uses a VLLM in the current process.
1+ """A backend that uses a SGLang in the current process.
22
3- The purpose of the VLLM backend is to provide a locally running fast inference engine.
3+ The purpose of the SGLang backend is to provide a locally running fast inference engine.
44"""
55
66from __future__ import annotations
1515from collections .abc import Callable
1616from typing import TYPE_CHECKING , Any , Optional
1717
18- import msgspec
19- import outlines
20- import outlines_core
18+ import sglang as sgl # type:ignore
2119import torch
22- import vllm
23- from transformers import PreTrainedTokenizerBase
20+ from transformers import AutoTokenizer , PreTrainedTokenizerBase
2421
2522from mellea .backends import BaseModelSubclass
2623from mellea .backends .formatter import Formatter , FormatterBackend , TemplateFormatter
4441from mellea .stdlib .chat import Message
4542from mellea .stdlib .requirement import LLMaJRequirement , Requirement
4643
47- assert outlines , "outlines needs to be present to make outlines_core work"
4844
45+ class LocalSGLangBackend (FormatterBackend ):
46+ """The LocalSGLangBackend uses SGLang's python offline inference interface, and uses a Formatter to convert `Component`s into prompts.
4947
50- class LocalVLLMBackend (FormatterBackend ):
51- """The LocalVLLMBackend uses vLLM's python interface for inference, and uses a Formatter to convert `Component`s into prompts.
48+ This backend is designed for running a HF model for small-scale inference locally on your machine.
5249
53- The support for Activated LoRAs (ALoras)](https://arxiv.org/pdf/2504.12397) is planned.
54-
55- This backend is designed for running an HF model for small-scale inference locally on your machine.
56-
57- Its throughput is generally higher than that of LocalHFBackend.
58- However, it takes longer to load the weights during the instantiation.
59- Also, if you submit a request one by one, it can be slower.
50+ Its throughput is generally higher than that of LocalHFBackend and LocalVLLMBackend.
6051 """
6152
6253 def __init__ (
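For orientation, here is a minimal sketch of the SGLang offline-engine interface this new backend wraps, pieced together from the calls visible in this diff; the model id, prompt, and sampling values are placeholders, not part of the patch:

    import sglang as sgl  # offline inference engine used by LocalSGLangBackend

    engine = sgl.Engine(model_path="ibm-granite/granite-3.3-8b-instruct")  # placeholder HF model id
    out = engine.generate(
        "Write a haiku about GPUs.",  # placeholder prompt
        sampling_params={"max_new_tokens": 64, "temperature": 0.7},
    )
    print(out["text"])  # a single prompt yields a dict whose "text" field holds the completion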
@@ -85,7 +76,7 @@ def __init__(
         # These are usually values that must be extracted before hand or that are common among backend providers
         self.to_mellea_model_opts_map = {
             # "system": ModelOption.SYSTEM_PROMPT,
-            "max_tokens": ModelOption.MAX_NEW_TOKENS,
+            "max_new_tokens": ModelOption.MAX_NEW_TOKENS,
             "seed": ModelOption.SEED,
             "temperature": ModelOption.TEMPERATURE,
         }
@@ -96,7 +87,7 @@ def __init__(
         # will be omitted here so that they will be removed when model_options are processed
         # for the call to the model.
         self.from_mellea_model_opts_map = {
-            ModelOption.MAX_NEW_TOKENS: "max_tokens",
+            ModelOption.MAX_NEW_TOKENS: "max_new_tokens",
             ModelOption.SEED: "seed",
             ModelOption.TEMPERATURE: "temperature",
         }
@@ -111,89 +102,11 @@ def __init__(
         )
         self._hf_model_id = model_id.hf_model_name
 
-        # vllm requires some model options during instantiation.
-        engine_args = self._simplify_and_merge(model_options)
-        engine_args = self._make_backend_specific_and_remove(
-            engine_args, vllm.EngineArgs
-        )
-
-        logger = FancyLogger.get_logger()
-        # Get the model and tokenizer.
-        # Getting vllm instantiated is tricky as it does not automatically detect some of these parameters.
-        engine_args["gpu_memory_utilization"] = engine_args.get(
-            "gpu_memory_utilization", 0.9
-        )
-        engine_args["max_num_seqs"] = engine_args.get("max_num_seqs", 16)
-        engine_args["max_model_len"] = engine_args.get("max_model_len", 16384)
-        logger.info(
-            f"Instantiating vllm with the following model parameters:\n"
-            f"gpu_memory_utilization: {engine_args['gpu_memory_utilization']}\n"
-            f"max_model_len: {engine_args['max_model_len']}\n"
-            f"max_num_seqs: {engine_args['max_num_seqs']}\n"
-        )
-        retry = 0
-        while True:
-            retry += 1
-            try:
-                self._model = vllm.LLM(
-                    model=self._hf_model_id, task="generate", **engine_args
-                )
-                break
-            except torch._dynamo.exc.BackendCompilerFailed as e:
-                # example:
-                # torch._dynamo.exc.BackendCompilerFailed: backend='<vllm.compilation.backends.VllmBackend object at 0x7f6d3f341730>' raised:
-                # RuntimeError: vLLM failed to compile the model. The most likely reason for this is that a previous compilation failed, leading to a corrupted compilation artifact. We recommend trying to remove ~/.cache/vllm/torch_compile_cache and try again to see the real issue.
-
-                if "~/.cache/vllm/torch_compile_cache" in str(e.inner_exception):
-                    logger.warning(
-                        "removing ~/.cache/vllm/torch_compile_cache and retry"
-                    )
-                    shutil.rmtree("~/.cache/vllm/torch_compile_cache")
-                # then retry
-
-            except Exception as e:
-                logger.info(e)
-                if retry % 3 == 0:
-                    engine_args["max_model_len"] //= 2
-                elif retry % 3 == 1:
-                    engine_args["max_num_seqs"] //= 2
-                elif retry % 3 == 2:
-                    engine_args["gpu_memory_utilization"] *= 0.9
-                if (
-                    engine_args["max_model_len"] == 0
-                    or engine_args["max_num_seqs"] == 0
-                    or engine_args["gpu_memory_utilization"] < 0.1
-                ):
-                    raise RuntimeError(
-                        "no matter how I reduced max_model_len and max_num_seqs, there is not enough memory! \n"
-                        "final values:\n"
-                        f"gpu_memory_utilization: {engine_args['gpu_memory_utilization']}\n"
-                        f"max_model_len: {engine_args['max_model_len']}\n"
-                        f"max_num_seqs: {engine_args['max_num_seqs']}\n"
-                    )
-                logger.info(
-                    f"Reducing vllm model parameters to make it fit in the GPU memory.\n"
-                    "current values:\n"
-                    f"gpu_memory_utilization: {engine_args['gpu_memory_utilization']}\n"
-                    f"max_model_len: {engine_args['max_model_len']}\n"
-                    f"max_num_seqs: {engine_args['max_num_seqs']}\n"
-                )
-
-        logger.info(
-            f"vllm instantiated.\n"
-            "final model parameters:\n"
-            f"gpu_memory_utilization: {engine_args['gpu_memory_utilization']}\n"
-            f"max_model_len: {engine_args['max_model_len']}\n"
-            f"max_num_seqs: {engine_args['max_num_seqs']}\n"
+        self._model = sgl.Engine(model_path=self._hf_model_id)
+        self._tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
+            self._hf_model_id
         )
 
-        self._tokenizer: PreTrainedTokenizerBase = self._model.get_tokenizer()  # type:ignore
-
-        # see notes in outlines.models.vllm.adapt_tokenizer
-        self._tokenizer_for_outlines: PreTrainedTokenizerBase = outlines.models.VLLM(
-            self._model
-        ).tokenizer  # type:ignore
-
     def generate_from_context(
         self,
         action: Component | CBlock,
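One note on the replacement above: the vLLM retry/auto-tuning loop is gone, and `sgl.Engine` is constructed with defaults. If the defaults do not fit the target GPU, SGLang's engine accepts analogous tuning knobs; the argument names below assume current SGLang server arguments and should be checked against the installed version:

    import sglang as sgl

    engine = sgl.Engine(
        model_path="ibm-granite/granite-3.3-8b-instruct",  # placeholder HF model id
        mem_fraction_static=0.85,  # assumed knob: fraction of GPU memory for weights + KV cache
        context_length=16384,      # assumed knob: cap the context length to bound KV-cache size
        tp_size=1,                 # assumed knob: tensor-parallel degree
    )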
@@ -262,38 +175,18 @@ def _generate_from_context_standard(
                 tools=convert_tools_to_json(tools),  # type: ignore
             )
 
-            sampling_params = vllm.SamplingParams(
-                **self._make_backend_specific_and_remove(
-                    model_options, vllm.SamplingParams
-                )
+            sampling_params: dict[str, Any] = self._make_backend_specific_and_remove(
+                model_options
             )
 
             if format is not None:
-                # outlines.generate.json always parses the resulting json into a python dict.
-                # We however want to keep it as a json string for later storing it in ModelOutputThunk
-                schema: dict[str, Any] = format.model_json_schema()
-                schema_json: str = json.dumps(schema)
-                regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(
-                    schema_json
-                )
-
-                from outlines.processors import RegexLogitsProcessor
+                sampling_params["json_schema"] = json.dumps(format.model_json_schema())
 
-                logits_processor = RegexLogitsProcessor(
-                    regex_str,
-                    tokenizer=self._tokenizer_for_outlines,  # type: ignore
-                )
-                sampling_params.logits_processors = (
-                    [logits_processor] if logits_processor is not None else []
-                )
-
-            ros: list[vllm.RequestOutput] = self._model.generate(  # type: ignore
-                [input_str], sampling_params=sampling_params
+            output: dict[str, Any] = self._model.generate(  # type: ignore
+                input_str, sampling_params=sampling_params
             )  # type: ignore
 
-            decoded_results = [ro.outputs[0].text for ro in ros]
-
-            decoded_result = decoded_results[0]
+            decoded_result = output["text"]
 
         else:
             raise Exception("Does not yet support non-chat contexts.")
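The outlines regex logits processor is replaced here by SGLang's `json_schema` sampling parameter. A sketch of what that constraint amounts to for a pydantic `format` class (the model below is illustrative; any `BaseModelSubclass` passed as `format` is handled the same way):

    import json
    from pydantic import BaseModel

    class Answer(BaseModel):  # illustrative schema, not part of the patch
        rating: int
        justification: str

    sampling_params = {"max_new_tokens": 200, "temperature": 0.0}
    sampling_params["json_schema"] = json.dumps(Answer.model_json_schema())
    # The engine is then expected to constrain decoding so output["text"] parses as Answer.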
@@ -339,34 +232,19 @@ def _generate_from_raw(
 
         prompts = [self.formatter.print(action) for action in actions]
 
-        sampling_params = vllm.SamplingParams(
-            **self._make_backend_specific_and_remove(model_options, vllm.SamplingParams)
+        sampling_params: dict[str, Any] = self._make_backend_specific_and_remove(
+            model_options
         )
 
         if format is not None:
-            schema: dict[str, Any] = format.model_json_schema()
-            schema_json: str = json.dumps(schema)
-            regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(
-                schema_json
-            )
+            sampling_params["json_schema"] = json.dumps(format.model_json_schema())
 
-            from outlines.processors import RegexLogitsProcessor
-
-            logits_processor = RegexLogitsProcessor(
-                regex_str,
-                tokenizer=self._tokenizer_for_outlines,  # type: ignore
-            )
-            sampling_params.logits_processors = (
-                [logits_processor] if logits_processor is not None else []
-            )
-
-        ros: list[vllm.RequestOutput] = self._model.generate(  # type: ignore
+        outputs: list[dict[str, Any]] = self._model.generate(  # type: ignore
             prompts, sampling_params=sampling_params
         )  # type: ignore
 
-        decoded_results = [ro.outputs[0].text for ro in ros]
-
-        results = [ModelOutputThunk(value=text) for text in decoded_results]
+        decoded_results = [output["text"] for output in outputs]
+        results = [ModelOutputThunk(value=output["text"]) for output in outputs]
 
         for i, result in enumerate(results):
             self.formatter.parse(actions[i], result)
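In `_generate_from_raw`, the same `generate` call is used in batch form: a list of prompts returns a list of result dicts in the same order. Reusing the `engine` from the first sketch (prompts are placeholders):

    outputs = engine.generate(
        ["Summarize document A.", "Summarize document B."],  # placeholder prompts
        sampling_params={"max_new_tokens": 128},
    )
    texts = [o["text"] for o in outputs]  # one completion per prompt, order preserved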
@@ -427,7 +305,7 @@ def _simplify_and_merge(
         )
 
     def _make_backend_specific_and_remove(
-        self, model_options: dict[str, Any], cls: type[Any]
+        self, model_options: dict[str, Any]
     ) -> dict[str, Any]:
         """Maps specified Mellea specific keys to their backend specific version and removes any remaining Mellea keys.
 
@@ -440,18 +318,4 @@ def _make_backend_specific_and_remove(
 
         backend_specific = ModelOption.replace_keys(
             model_options, self.from_mellea_model_opts_map
         )
-        backend_specific = ModelOption.remove_special_keys(backend_specific)
-        try:
-            # note: dataclasses.Field objects
-            return {
-                field.name: backend_specific[field.name]
-                for field in dataclasses.fields(cls)
-                if field.name in backend_specific
-            }
-        except TypeError:
-            # note: msgspec.structs.FieldInfo objects
-            return {
-                field.name: backend_specific[field.name]
-                for field in msgspec.structs.fields(cls)
-                if field.name in backend_specific
-            }
+        return ModelOption.remove_special_keys(backend_specific)
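With the `cls` filtering dropped, `_make_backend_specific_and_remove` now only renames Mellea option keys and strips the remaining Mellea-specific ones. An illustrative input/output pair, assuming a constructed `backend` instance (values are made up):

    model_options = {
        ModelOption.MAX_NEW_TOKENS: 256,
        ModelOption.TEMPERATURE: 0.2,
        ModelOption.SYSTEM_PROMPT: "Answer tersely.",  # remaining Mellea key, stripped per the docstring
    }
    sampling_params = backend._make_backend_specific_and_remove(model_options)
    # -> {"max_new_tokens": 256, "temperature": 0.2}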