
Commit a2423c1

fix(sglang): made it work

1 parent 2080309 commit a2423c1
File tree: 1 file changed (+27 −163 lines)

mellea/backends/sglang.py

Lines changed: 27 additions & 163 deletions
@@ -1,6 +1,6 @@
-"""A backend that uses a VLLM in the current process.
+"""A backend that uses SGLang in the current process.
 
-The purpose of the VLLM backend is to provide a locally running fast inference engine.
+The purpose of the SGLang backend is to provide a locally running fast inference engine.
 """
 
 from __future__ import annotations
@@ -15,12 +15,9 @@
 from collections.abc import Callable
 from typing import TYPE_CHECKING, Any, Optional
 
-import msgspec
-import outlines
-import outlines_core
+import sglang as sgl  # type:ignore
 import torch
-import vllm
-from transformers import PreTrainedTokenizerBase
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
 from mellea.backends import BaseModelSubclass
 from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
@@ -44,19 +41,13 @@
 from mellea.stdlib.chat import Message
 from mellea.stdlib.requirement import LLMaJRequirement, Requirement
 
-assert outlines, "outlines needs to be present to make outlines_core work"
 
+class LocalSGLangBackend(FormatterBackend):
+    """The LocalSGLangBackend uses SGLang's Python offline inference interface and a Formatter to convert `Component`s into prompts.
 
-class LocalVLLMBackend(FormatterBackend):
-    """The LocalVLLMBackend uses vLLM's python interface for inference, and uses a Formatter to convert `Component`s into prompts.
+    This backend is designed for running an HF model for small-scale inference locally on your machine.
 
-    The support for Activated LoRAs (ALoras)](https://arxiv.org/pdf/2504.12397) is planned.
-
-    This backend is designed for running an HF model for small-scale inference locally on your machine.
-
-    Its throughput is generally higher than that of LocalHFBackend.
-    However, it takes longer to load the weights during the instantiation.
-    Also, if you submit a request one by one, it can be slower.
+    Its throughput is generally higher than that of LocalHFBackend and LocalVLLMBackend.
     """
 
     def __init__(
@@ -85,7 +76,7 @@ def __init__(
         # These are usually values that must be extracted before hand or that are common among backend providers
         self.to_mellea_model_opts_map = {
             # "system": ModelOption.SYSTEM_PROMPT,
-            "max_tokens": ModelOption.MAX_NEW_TOKENS,
+            "max_new_tokens": ModelOption.MAX_NEW_TOKENS,
             "seed": ModelOption.SEED,
             "temperature": ModelOption.TEMPERATURE,
         }
@@ -96,7 +87,7 @@ def __init__(
         # will be omitted here so that they will be removed when model_options are processed
         # for the call to the model.
         self.from_mellea_model_opts_map = {
-            ModelOption.MAX_NEW_TOKENS: "max_tokens",
+            ModelOption.MAX_NEW_TOKENS: "max_new_tokens",
             ModelOption.SEED: "seed",
             ModelOption.TEMPERATURE: "temperature",
         }
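
Note: the rename in this hunk is the whole point: vLLM's SamplingParams takes max_tokens, while SGLang's sampling_params dict takes max_new_tokens. A minimal sketch of that translation, using plain dictionaries as illustrative stand-ins for the real option objects (not part of the commit):

# Illustration only: translating a generic option dict for each backend.
generic_options = {"max_new_tokens": 128, "temperature": 0.2, "seed": 42}

as_vllm_kwargs = {("max_tokens" if k == "max_new_tokens" else k): v for k, v in generic_options.items()}
as_sglang_params = dict(generic_options)  # SGLang already expects "max_new_tokens"

print(as_vllm_kwargs)    # {'max_tokens': 128, 'temperature': 0.2, 'seed': 42}
print(as_sglang_params)  # {'max_new_tokens': 128, 'temperature': 0.2, 'seed': 42}
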
@@ -111,89 +102,11 @@ def __init__(
         )
         self._hf_model_id = model_id.hf_model_name
 
-        # vllm requires some model options during instantiation.
-        engine_args = self._simplify_and_merge(model_options)
-        engine_args = self._make_backend_specific_and_remove(
-            engine_args, vllm.EngineArgs
-        )
-
-        logger = FancyLogger.get_logger()
-        # Get the model and tokenizer.
-        # Getting vllm instantiated is tricky as it does not automatically detect some of these parameters.
-        engine_args["gpu_memory_utilization"] = engine_args.get(
-            "gpu_memory_utilization", 0.9
-        )
-        engine_args["max_num_seqs"] = engine_args.get("max_num_seqs", 16)
-        engine_args["max_model_len"] = engine_args.get("max_model_len", 16384)
-        logger.info(
-            f"Instantiating vllm with the following model parameters:\n"
-            f"gpu_memory_utilization: {engine_args['gpu_memory_utilization']}\n"
-            f"max_model_len: {engine_args['max_model_len']}\n"
-            f"max_num_seqs: {engine_args['max_num_seqs']}\n"
-        )
-        retry = 0
-        while True:
-            retry += 1
-            try:
-                self._model = vllm.LLM(
-                    model=self._hf_model_id, task="generate", **engine_args
-                )
-                break
-            except torch._dynamo.exc.BackendCompilerFailed as e:
-                # example:
-                # torch._dynamo.exc.BackendCompilerFailed: backend='<vllm.compilation.backends.VllmBackend object at 0x7f6d3f341730>' raised:
-                # RuntimeError: vLLM failed to compile the model. The most likely reason for this is that a previous compilation failed, leading to a corrupted compilation artifact. We recommend trying to remove ~/.cache/vllm/torch_compile_cache and try again to see the real issue.
-
-                if "~/.cache/vllm/torch_compile_cache" in str(e.inner_exception):
-                    logger.warning(
-                        "removing ~/.cache/vllm/torch_compile_cache and retry"
-                    )
-                    shutil.rmtree("~/.cache/vllm/torch_compile_cache")
-                # then retry
-
-            except Exception as e:
-                logger.info(e)
-                if retry % 3 == 0:
-                    engine_args["max_model_len"] //= 2
-                elif retry % 3 == 1:
-                    engine_args["max_num_seqs"] //= 2
-                elif retry % 3 == 2:
-                    engine_args["gpu_memory_utilization"] *= 0.9
-                if (
-                    engine_args["max_model_len"] == 0
-                    or engine_args["max_num_seqs"] == 0
-                    or engine_args["gpu_memory_utilization"] < 0.1
-                ):
-                    raise RuntimeError(
-                        "no matter how I reduced max_model_len and max_num_seqs, there is not enough memory! \n"
-                        "final values:\n"
-                        f"gpu_memory_utilization: {engine_args['gpu_memory_utilization']}\n"
-                        f"max_model_len: {engine_args['max_model_len']}\n"
-                        f"max_num_seqs: {engine_args['max_num_seqs']}\n"
-                    )
-                logger.info(
-                    f"Reducing vllm model parameters to make it fit in the GPU memory.\n"
-                    "current values:\n"
-                    f"gpu_memory_utilization: {engine_args['gpu_memory_utilization']}\n"
-                    f"max_model_len: {engine_args['max_model_len']}\n"
-                    f"max_num_seqs: {engine_args['max_num_seqs']}\n"
-                )
-
-        logger.info(
-            f"vllm instantiated.\n"
-            "final model parameters:\n"
-            f"gpu_memory_utilization: {engine_args['gpu_memory_utilization']}\n"
-            f"max_model_len: {engine_args['max_model_len']}\n"
-            f"max_num_seqs: {engine_args['max_num_seqs']}\n"
+        self._model = sgl.Engine(model_path=self._hf_model_id)
+        self._tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
+            self._hf_model_id
         )
 
-        self._tokenizer: PreTrainedTokenizerBase = self._model.get_tokenizer()  # type:ignore
-
-        # see notes in outlines.models.vllm.adapt_tokenizer
-        self._tokenizer_for_outlines: PreTrainedTokenizerBase = outlines.models.VLLM(
-            self._model
-        ).tokenizer  # type:ignore
-
     def generate_from_context(
         self,
         action: Component | CBlock,
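
Note: the replacement body relies on SGLang's offline engine, which loads the model in-process much as vllm.LLM did, but without the engine-argument retry loop. A minimal standalone sketch of that pattern (the model id and prompt are placeholders, not taken from the commit):

import sglang as sgl
from transformers import AutoTokenizer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder model id

engine = sgl.Engine(model_path=model_id)             # in-process offline engine
tokenizer = AutoTokenizer.from_pretrained(model_id)  # tokenizer fetched separately from HF

# Apply the chat template ourselves, then generate from the raw prompt string.
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Say hello in one sentence."}],
    tokenize=False,
    add_generation_prompt=True,
)
out = engine.generate(prompt, sampling_params={"temperature": 0.0, "max_new_tokens": 32})
print(out["text"])

engine.shutdown()  # release the GPU when done
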
@@ -262,38 +175,18 @@ def _generate_from_context_standard(
             tools=convert_tools_to_json(tools),  # type: ignore
         )
 
-        sampling_params = vllm.SamplingParams(
-            **self._make_backend_specific_and_remove(
-                model_options, vllm.SamplingParams
-            )
+        sampling_params: dict[str, Any] = self._make_backend_specific_and_remove(
+            model_options
         )
 
         if format is not None:
-            # outlines.generate.json always parses the resulting json into a python dict.
-            # We however want to keep it as a json string for later storing it in ModelOutputThunk
-            schema: dict[str, Any] = format.model_json_schema()
-            schema_json: str = json.dumps(schema)
-            regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(
-                schema_json
-            )
-
-            from outlines.processors import RegexLogitsProcessor
+            sampling_params["json_schema"] = json.dumps(format.model_json_schema())
 
-            logits_processor = RegexLogitsProcessor(
-                regex_str,
-                tokenizer=self._tokenizer_for_outlines,  # type: ignore
-            )
-            sampling_params.logits_processors = (
-                [logits_processor] if logits_processor is not None else []
-            )
-
-            ros: list[vllm.RequestOutput] = self._model.generate(  # type: ignore
-                [input_str], sampling_params=sampling_params
+            output: dict[str, Any] = self._model.generate(  # type: ignore
+                input_str, sampling_params=sampling_params
             )  # type: ignore
 
-            decoded_results = [ro.outputs[0].text for ro in ros]
-
-            decoded_result = decoded_results[0]
+            decoded_result = output["text"]
 
         else:
             raise Exception("Does not yet support non-chat contexts.")
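
Note: this hunk swaps the outlines regex logits processor for SGLang's built-in structured output: a JSON-schema string passed under the json_schema key of sampling_params constrains decoding, and the raw JSON text stays available for the ModelOutputThunk. A sketch of that path (the pydantic model, prompt, and model id are illustrative, not from the commit):

import json
import sglang as sgl
from pydantic import BaseModel

class Person(BaseModel):  # illustrative schema
    name: str
    age: int

engine = sgl.Engine(model_path="Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model id

sampling_params = {
    "temperature": 0.0,
    "max_new_tokens": 64,
    "json_schema": json.dumps(Person.model_json_schema()),  # constrains decoding to the schema
}
out = engine.generate("Describe a fictional person as JSON.", sampling_params=sampling_params)
print(out["text"])                              # raw JSON string, as the backend stores it
print(Person.model_validate_json(out["text"]))  # parsed only when the caller wants it

engine.shutdown()
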
@@ -339,34 +232,19 @@ def _generate_from_raw(
 
         prompts = [self.formatter.print(action) for action in actions]
 
-        sampling_params = vllm.SamplingParams(
-            **self._make_backend_specific_and_remove(model_options, vllm.SamplingParams)
+        sampling_params: dict[str, Any] = self._make_backend_specific_and_remove(
+            model_options
         )
 
         if format is not None:
-            schema: dict[str, Any] = format.model_json_schema()
-            schema_json: str = json.dumps(schema)
-            regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(
-                schema_json
-            )
+            sampling_params["json_schema"] = json.dumps(format.model_json_schema())
 
-            from outlines.processors import RegexLogitsProcessor
-
-            logits_processor = RegexLogitsProcessor(
-                regex_str,
-                tokenizer=self._tokenizer_for_outlines,  # type: ignore
-            )
-            sampling_params.logits_processors = (
-                [logits_processor] if logits_processor is not None else []
-            )
-
-        ros: list[vllm.RequestOutput] = self._model.generate(  # type: ignore
+        outputs: list[dict[str, Any]] = self._model.generate(  # type: ignore
             prompts, sampling_params=sampling_params
         )  # type: ignore
 
-        decoded_results = [ro.outputs[0].text for ro in ros]
-
-        results = [ModelOutputThunk(value=text) for text in decoded_results]
+        decoded_results = [output["text"] for output in outputs]
+        results = [ModelOutputThunk(value=output["text"]) for output in outputs]
 
         for i, result in enumerate(results):
             self.formatter.parse(actions[i], result)
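
Note: the raw-generation path now sends the whole prompt list in one call; SGLang returns one result dict per prompt, in order, which is what the list comprehensions above consume. A small sketch of the batch call (prompts and model id are placeholders):

import sglang as sgl

engine = sgl.Engine(model_path="Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model id

prompts = ["Write a haiku about the sea.", "Write a haiku about mountains."]
outputs = engine.generate(prompts, sampling_params={"temperature": 0.8, "max_new_tokens": 48})
for prompt, out in zip(prompts, outputs):  # outputs align with prompts by position
    print(prompt, "->", out["text"])

engine.shutdown()
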
@@ -427,7 +305,7 @@ def _simplify_and_merge(
         )
 
     def _make_backend_specific_and_remove(
-        self, model_options: dict[str, Any], cls: type[Any]
+        self, model_options: dict[str, Any]
    ) -> dict[str, Any]:
         """Maps specified Mellea specific keys to their backend specific version and removes any remaining Mellea keys.
 
@@ -440,18 +318,4 @@ def _make_backend_specific_and_remove(
         backend_specific = ModelOption.replace_keys(
             model_options, self.from_mellea_model_opts_map
         )
-        backend_specific = ModelOption.remove_special_keys(backend_specific)
-        try:
-            # note: dataclasses.Field objects
-            return {
-                field.name: backend_specific[field.name]
-                for field in dataclasses.fields(cls)
-                if field.name in backend_specific
-            }
-        except TypeError:
-            # note: msgspec.structs.FieldInfo objects
-            return {
-                field.name: backend_specific[field.name]
-                for field in msgspec.structs.fields(cls)
-                if field.name in backend_specific
-            }
+        return ModelOption.remove_special_keys(backend_specific)
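
Note: the cls-based field filtering could be dropped because SGLang accepts a plain sampling_params dict, so only the key renaming and the removal of Mellea-internal keys remain. A hypothetical stand-in for that behaviour (the helper functions and the "mellea::" prefix are invented for illustration; the real logic lives in ModelOption):

# Hypothetical illustration; not Mellea's ModelOption implementation.
MAX_NEW_TOKENS = "mellea::max_new_tokens"  # stand-in for ModelOption.MAX_NEW_TOKENS
SYSTEM_PROMPT = "mellea::system_prompt"    # stand-in for a key with no SGLang equivalent

def replace_keys(opts: dict, mapping: dict) -> dict:
    # Rename any key found in the mapping; leave everything else untouched.
    return {mapping.get(k, k): v for k, v in opts.items()}

def remove_special_keys(opts: dict) -> dict:
    # Drop keys that have no backend-specific meaning.
    return {k: v for k, v in opts.items() if not k.startswith("mellea::")}

opts = {MAX_NEW_TOKENS: 128, SYSTEM_PROMPT: "You are terse.", "top_p": 0.9}
backend_specific = replace_keys(opts, {MAX_NEW_TOKENS: "max_new_tokens"})
print(remove_special_keys(backend_specific))
# -> {'max_new_tokens': 128, 'top_p': 0.9}: ready to pass straight to engine.generate(...)
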
