Commit dfb53d1

feat: add reasoning (#752)
1 parent a97b83d commit dfb53d1

11 files changed (+207 additions, -69 deletions)

docs/how-to/llms/use_llms.md

Lines changed: 4 additions & 1 deletion
@@ -29,7 +29,7 @@ LLMs in Ragbits allow you to customize the behavior of the model using various o
 
 ### LiteLLM Options
 
-The `LiteLLMOptions` class provides options for remote LLMs, aligning with the LiteLLM API. These options allow you to control the behavior of models from various providers. Each of the options is described in the [LiteLLM documentation](https://docs.litellm.ai/docs/completion/input).
+The `LiteLLMOptions` class provides options for remote LLMs, aligning with the LiteLLM API. These options allow you to control the behavior of models from various providers. Each of the options is described in the [LiteLLM documentation](https://docs.litellm.ai/docs/completion/input) and the [LiteLLM reasoning documentation](https://docs.litellm.ai/docs/reasoning_content).
 
 Example usage:
 
 ```python
@@ -47,6 +47,9 @@ response = llm.generate("Write a short story about a robot learning to paint.")
 print(response)
 ```
 
+!!! warning
+    If you provide `reasoning_effort` to an OpenAI model, [the reasoning content will not be returned](https://platform.openai.com/docs/guides/reasoning?api-mode=responses).
+
 ## Using Local LLMs
 
 For guidance on setting up and using local models in Ragbits, refer to the [Local LLMs Guide](https://ragbits.deepsense.ai/how-to/llms/use_local_llms/).
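
Not part of the commit, but for readers of the docs change above, a minimal usage sketch of the documented options, assuming the `LiteLLM`/`LiteLLMOptions` API shown elsewhere in this diff; the model name is illustrative:

```python
import asyncio

from ragbits.core.llms import LiteLLM, LiteLLMOptions


async def main() -> None:
    # Request a reasoning trace alongside the answer.
    options = LiteLLMOptions(reasoning_effort="medium")
    llm = LiteLLM(model_name="claude-3-7-sonnet-20250219", default_options=options)

    response = await llm.generate_with_metadata("Write a short story about a robot learning to paint.")
    print(response.reasoning)  # the trace, or None when the provider does not return one (see the warning above)
    print(response.content)    # the final answer


asyncio.run(main())
```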

examples/core/llms/reasoning.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+"""
+Ragbits Core Example: Reasoning with LLM
+
+This example demonstrates how to use reasoning with LLM.
+
+To run the script, execute the following command:
+
+```bash
+uv run examples/core/llms/reasoning.py
+```
+"""
+
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#     "ragbits-core",
+# ]
+# ///
+
+import asyncio
+
+from ragbits.core.llms import LiteLLM, LiteLLMOptions
+
+
+async def main() -> None:
+    """
+    Run the example.
+    """
+    options = LiteLLMOptions(reasoning_effort="medium")
+    model = LiteLLM(model_name="claude-3-7-sonnet-20250219", default_options=options)
+    response = await model.generate_with_metadata(
+        "Do you like Jazz?",
+    )
+    print(f"reasoning: {response.reasoning}")
+
+    options = LiteLLMOptions(thinking={"type": "enabled", "budget_tokens": 1024})
+    model = LiteLLM(model_name="claude-3-7-sonnet-20250219", default_options=options)
+    response = await model.generate_with_metadata(
+        "Do you like Jazz?",
+    )
+    print(f"reasoning: {response.reasoning}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
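
Both calls in the new example target the same Anthropic model: the first uses the provider-agnostic `reasoning_effort` option, the second the Anthropic-native `thinking` budget, and in both cases the trace is read back from `response.reasoning`.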

packages/ragbits-core/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 
 ## Unreleased
 
+- Add support for Reasoning models (#752)
 - Fix issue with cost calculation for some models (#748)
 - Fix issue with improper conversion to json of tool call arguments (#737)
 - Added Google Drive support (#686)

packages/ragbits-core/src/ragbits/core/llms/base.py

Lines changed: 18 additions & 5 deletions
@@ -197,13 +197,18 @@ def __repr__(self) -> str:
         )
 
 
+class Reasoning(str):
+    """A class for reasoning streaming"""
+
+
 class LLMResponseWithMetadata(BaseModel, Generic[PromptOutputT]):
     """
     A schema of output with metadata
     """
 
     content: PromptOutputT
     metadata: dict = {}
+    reasoning: str | None = None
     tool_calls: list[ToolCall] | None = None
     usage: Usage | None = None
 
@@ -571,12 +576,14 @@ async def generate_with_metadata(
         )
 
         content = response.pop("response")
+        reasoning = response.pop("reasoning", None)
 
         if isinstance(prompt, BasePromptWithParser) and content:
             content = await prompt.parse_response(content)
 
         response_with_metadata = LLMResponseWithMetadata[type(content)](  # type: ignore
             content=content,
+            reasoning=reasoning,
             tool_calls=tool_calls,
             metadata=response,
             usage=usage,
@@ -623,7 +630,7 @@ def generate_streaming(
         *,
         tools: None = None,
         options: LLMClientOptionsT | None = None,
-    ) -> LLMResultStreaming[str]: ...
+    ) -> LLMResultStreaming[str | Reasoning]: ...
 
     @overload
     def generate_streaming(
@@ -632,7 +639,7 @@ def generate_streaming(
         *,
         tools: list[Tool],
         options: LLMClientOptionsT | None = None,
-    ) -> LLMResultStreaming[str | ToolCall]: ...
+    ) -> LLMResultStreaming[str | Reasoning | ToolCall]: ...
 
     def generate_streaming(
         self,
@@ -661,7 +668,7 @@ async def _stream_internal(
         *,
         tools: list[Tool] | None = None,
         options: LLMClientOptionsT | None = None,
-    ) -> AsyncGenerator[str | ToolCall | LLMResponseWithMetadata, None]:
+    ) -> AsyncGenerator[str | Reasoning | ToolCall | LLMResponseWithMetadata, None]:
         with trace(model_name=self.model_name, prompt=prompt, options=repr(options)) as outputs:
             merged_options = (self.default_options | options) if options else self.default_options
             if isinstance(prompt, str | list):
@@ -679,12 +686,17 @@ async def _stream_internal(
             )
 
             content = ""
+            reasoning = ""
            tool_calls = []
             usage_data = {}
             async for chunk in response:
                 if text := chunk.get("response"):
-                    content += text
-                    yield text
+                    if chunk.get("reasoning"):
+                        reasoning += text
+                        yield Reasoning(text)
+                    else:
+                        content += text
+                        yield text
 
                 if tools and (_tool_calls := chunk.get("tool_calls")):
                     for tool_call in _tool_calls:
@@ -706,6 +718,7 @@
 
             outputs.response = LLMResponseWithMetadata[type(content or None)](  # type: ignore
                 content=content or None,
+                reasoning=reasoning or None,
                 tool_calls=tool_calls or None,
                 usage=usage,
             )
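
The `Reasoning` marker type added above is a `str` subclass, so streaming consumers can separate reasoning deltas from answer deltas with an `isinstance` check. A minimal sketch, not part of the commit; the import path for `Reasoning` follows the module it is defined in here (it may also be re-exported), and the model name is illustrative:

```python
import asyncio

from ragbits.core.llms import LiteLLM, LiteLLMOptions
from ragbits.core.llms.base import Reasoning


async def stream_answer() -> None:
    options = LiteLLMOptions(reasoning_effort="low")
    llm = LiteLLM(model_name="claude-3-7-sonnet-20250219", default_options=options)

    async for chunk in llm.generate_streaming("Do you like Jazz?"):
        if isinstance(chunk, Reasoning):
            # Reasoning deltas arrive as the str subclass, so check for them before plain str.
            print(f"[reasoning] {chunk}")
        else:
            print(chunk, end="")


asyncio.run(stream_answer())
```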

packages/ragbits-core/src/ragbits/core/llms/exceptions.py

Lines changed: 9 additions & 0 deletions
@@ -70,3 +70,12 @@ class LLMNotSupportingToolUseError(LLMError):
 
     def __init__(self, message: str = "There are tools provided, but given LLM doesn't support tool use.") -> None:
         super().__init__(message)
+
+
+class LLMNotSupportingReasoningEffortError(LLMError):
+    """
+    Raised when there is reasoning effort provided, but LLM doesn't support it.
+    """
+
+    def __init__(self, model_name: str) -> None:
+        super().__init__(f"Model {model_name} does not support reasoning effort.")

packages/ragbits-core/src/ragbits/core/llms/litellm.py

Lines changed: 22 additions & 4 deletions
@@ -1,7 +1,7 @@
 import asyncio
 import time
 from collections.abc import AsyncGenerator, Callable, Iterable
-from typing import Any
+from typing import Any, Literal
 
 import litellm
 import tiktoken
@@ -17,6 +17,7 @@
     LLMEmptyResponseError,
     LLMNotSupportingImagesError,
     LLMNotSupportingPdfsError,
+    LLMNotSupportingReasoningEffortError,
     LLMNotSupportingToolUseError,
     LLMResponseError,
     LLMStatusError,
@@ -29,6 +30,7 @@ class LiteLLMOptions(LLMOptions):
     """
     Dataclass that represents all available LLM call options for the LiteLLM client.
     Each of them is described in the [LiteLLM documentation](https://docs.litellm.ai/docs/completion/input).
+    Reasoning effort and thinking are described in the [LiteLLM reasoning documentation](https://docs.litellm.ai/docs/reasoning_content).
     """
 
     frequency_penalty: float | None | NotGiven = NOT_GIVEN
@@ -45,6 +47,8 @@ class LiteLLMOptions(LLMOptions):
     mock_response: str | None | NotGiven = NOT_GIVEN
     tpm: int | None | NotGiven = NOT_GIVEN
     rpm: int | None | NotGiven = NOT_GIVEN
+    reasoning_effort: Literal["low", "medium", "high"] | None | NotGiven = NOT_GIVEN
+    thinking: dict | None | NotGiven = NOT_GIVEN
 
 
 class LiteLLM(LLM[LiteLLMOptions]):
@@ -185,6 +189,9 @@ async def _call(
         if tools and not litellm.supports_function_calling(self.model_name):
             raise LLMNotSupportingToolUseError()
 
+        if options.reasoning_effort and not litellm.supports_reasoning(self.model_name):
+            raise LLMNotSupportingReasoningEffortError(self.model_name)
+
         start_time = time.perf_counter()
         raw_responses = await asyncio.gather(
             *(
@@ -209,6 +216,7 @@ async def _call(
 
             result = {}
             result["response"] = response.choices[0].message.content  # type: ignore
+            result["reasoning"] = getattr(response.choices[0].message, "reasoning_content", None)  # type: ignore
             result["throughput"] = throughput_batch / float(len(raw_responses))
 
             result["tool_calls"] = (
@@ -274,6 +282,9 @@ async def _call_streaming(
         if tools and not litellm.supports_function_calling(self.model_name):
             raise LLMNotSupportingToolUseError()
 
+        if options.reasoning_effort and not litellm.supports_reasoning(self.model_name):
+            raise LLMNotSupportingReasoningEffortError(self.model_name)
+
         response_format = self._get_response_format(output_schema=prompt.output_schema(), json_mode=prompt.json_mode)
         input_tokens = self.count_tokens(prompt)
 
@@ -288,7 +299,6 @@ async def _call_streaming(
             stream=True,
             stream_options={"include_usage": True},
         )
-
         if not response.completion_stream and not response.choices:  # type: ignore
             raise LLMEmptyResponseError()
 
@@ -298,7 +308,8 @@ async def response_to_async_generator(response: CustomStreamWrapper) -> AsyncGen
            tool_calls: list[dict] = []
 
            async for item in response:
-                if content := item.choices[0].delta.content:
+                reasoning_content = getattr(item.choices[0].delta, "reasoning_content", None)
+                if content := item.choices[0].delta.content or reasoning_content:
                    output_tokens += 1
                    if output_tokens == 1:
                        record_metric(
@@ -308,7 +319,8 @@ async def response_to_async_generator(response: CustomStreamWrapper) -> AsyncGen
                            model=self.model_name,
                            prompt=prompt.__class__.__name__,
                        )
-                    yield {"response": content}
+
+                    yield {"response": content, "reasoning": bool(reasoning_content)}
 
                if tool_calls_delta := item.choices[0].delta.tool_calls:
                    for tool_call_chunk in tool_calls_delta:
@@ -412,6 +424,12 @@ async def _get_litellm_response(
            **options.dict(),
        }
 
+        supported_openai_params = litellm.get_supported_openai_params(model=self.model_name) or []
+        if "reasoning_effort" not in supported_openai_params:
+            completion_kwargs.pop("reasoning_effort")
+        if "thinking" not in supported_openai_params:
+            completion_kwargs.pop("thinking")
+
        if stream_options is not None:
            completion_kwargs["stream_options"] = stream_options
 
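
The gating in this file relies on two LiteLLM helpers that are also usable standalone; a small sketch of what they report (the model names are only examples):

```python
import litellm

for model_name in ("claude-3-7-sonnet-20250219", "gpt-4o-mini"):
    # Used above to raise LLMNotSupportingReasoningEffortError before any request is made.
    print(model_name, "supports reasoning:", litellm.supports_reasoning(model_name))

    # Used in _get_litellm_response to drop the "reasoning_effort"/"thinking" kwargs
    # that the target provider does not accept.
    supported = litellm.get_supported_openai_params(model=model_name) or []
    print(model_name, "accepts reasoning_effort:", "reasoning_effort" in supported)
    print(model_name, "accepts thinking:", "thinking" in supported)
```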

packages/ragbits-core/src/ragbits/core/llms/local.py

Lines changed: 2 additions & 1 deletion
@@ -157,6 +157,7 @@ async def _call(
         for i, response in enumerate(responses):
             result = {}
             result["response"] = self.tokenizer.decode(response, skip_special_tokens=True)
+            result["reasoning"] = None
             prompt_tokens = tokens_in[i]
             completion_tokens = sum(response != self.tokenizer._pad_token_type_id)
             result["usage"] = {
@@ -222,7 +223,7 @@ async def streamer_to_async_generator(
                     prompt=prompt.__class__.__name__,
                 )
 
-                yield {"response": text}
+                yield {"response": text, "reasoning": False}
                 await asyncio.sleep(0.0)
 
             generation_thread.join()

packages/ragbits-core/src/ragbits/core/llms/mock.py

Lines changed: 10 additions & 0 deletions
@@ -14,6 +14,8 @@ class MockLLMOptions(LLMOptions):
     response: str | NotGiven = NOT_GIVEN
     response_stream: list[str] | NotGiven = NOT_GIVEN
     tool_calls: list[dict] | NotGiven = NOT_GIVEN
+    reasoning: str | NotGiven = NOT_GIVEN
+    reasoning_stream: list[str] | NotGiven = NOT_GIVEN
 
 
 class MockLLM(LLM[MockLLMOptions]):
@@ -69,6 +71,7 @@ async def _call( # noqa: PLR6301
         prompt = list(prompt)
         self.calls.extend([p.chat for p in prompt])
         response = "mocked response" if isinstance(options.response, NotGiven) else options.response
+        reasoning = None if isinstance(options.reasoning, NotGiven) else options.reasoning
         tool_calls = (
             None
             if isinstance(options.tool_calls, NotGiven)
@@ -78,6 +81,7 @@ async def _call( # noqa: PLR6301
         return [
             {
                 "response": response,
+                "reasoning": reasoning,
                 "tool_calls": tool_calls,
                 "is_mocked": True,
                 "throughput": 1 / len(prompt),
@@ -107,10 +111,16 @@ async def generator() -> AsyncGenerator[dict, None]:
             ):
                 yield {"tool_calls": options.tool_calls}
             elif not isinstance(options.response_stream, NotGiven):
+                if not isinstance(options.reasoning_stream, NotGiven):
+                    for reasoning in options.reasoning_stream:
+                        yield {"response": reasoning, "reasoning": True}
                 for response in options.response_stream:
                     yield {"response": response}
             elif not isinstance(options.response, NotGiven):
+                if not isinstance(options.reasoning, NotGiven):
+                    yield {"response": options.reasoning, "reasoning": True}
                 yield {"response": options.response}
+
             else:
                 yield {"response": "mocked response"}
 
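
A minimal sketch of exercising the new mock options in a test, assuming pytest with pytest-asyncio (neither is part of this commit) and that `MockLLM` accepts `default_options` like the other LLM classes:

```python
import pytest

from ragbits.core.llms.mock import MockLLM, MockLLMOptions


@pytest.mark.asyncio
async def test_mock_llm_returns_reasoning() -> None:
    options = MockLLMOptions(response="final answer", reasoning="mocked chain of thought")
    llm = MockLLM(default_options=options)

    response = await llm.generate_with_metadata("Do you like Jazz?")

    assert response.content == "final answer"
    assert response.reasoning == "mocked chain of thought"
```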
