4 changes: 0 additions & 4 deletions backend/openedx_ai_extensions/processors/litellm_base_processor.py
@@ -59,10 +59,6 @@ def __init__(self, config=None, user_session=None):
        if functions_schema_filtered:
            self.extra_params["tools"] = functions_schema_filtered

-        if self.stream and "tools" in self.extra_params:
-            logger.warning("Streaming responses with tools is not supported; disabling streaming.")
-            self.stream = False
-
        self.mcp_configs = {}
        allowed_mcp_configs = self.config.get("mcp_configs", [])
        if allowed_mcp_configs:
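Context for this removal: the streaming tool-call handler added to `llm_processor.py` below makes the guard unnecessary, so a configuration that enables both streaming and tools is no longer downgraded to non-streaming. A hypothetical config of that shape (mirroring the test deleted further down):

```python
# Hypothetical processor config: both options can now be active at once.
config = {
    "LitellmProcessor": {
        "stream": True,                  # previously forced to False when tools were present
        "enabled_tools": ["roll_dice"],  # tool use no longer disables streaming
    }
}
```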
120 changes: 113 additions & 7 deletions backend/openedx_ai_extensions/processors/llm_processor.py
@@ -245,12 +245,26 @@ def _completion_with_tools(self, tool_calls, params):
"""Handle tool calls recursively until no more tool calls are present."""
for tool_call in tool_calls:
function_name = tool_call.function.name

# Ensure tool exists
if function_name not in AVAILABLE_TOOLS:
logger.error(f"Tool '{function_name}' requested by LLM but not available locally.")
continue

function_to_call = AVAILABLE_TOOLS[function_name]
function_args = json.loads(tool_call.function.arguments)
logger.info(f"[LLM] Tool call: {function_to_call}")

try:
function_args = json.loads(tool_call.function.arguments)
function_response = function_to_call(**function_args)
logger.info(f"[LLM] Response from tool call: {function_response}")
except json.JSONDecodeError:
function_response = "Error: Invalid JSON arguments provided."
logger.error(f"Failed to parse JSON arguments for {function_name}")
except Exception as e: # pylint: disable=broad-exception-caught
function_response = f"Error executing tool: {str(e)}"
logger.error(f"Error executing tool {function_name}: {e}")

function_response = function_to_call(
**function_args,
)
params["messages"].append(
{
"tool_call_id": tool_call.id,
@@ -263,10 +277,9 @@ def _completion_with_tools(self, tool_calls, params):
        # Call completion again with updated messages
        response = completion(**params)

-        # For streaming, return the generator immediately
-        # Tool calls are not supported in streaming mode
+        # For streaming, we need to handle the stream to detect tool calls
        if params.get("stream"):
-            return response
+            return self._handle_streaming_tool_calls(response, params)

        # For non-streaming, check for tool calls and handle recursively
        new_tool_calls = response.choices[0].message.tool_calls
@@ -276,6 +289,99 @@ def _completion_with_tools(self, tool_calls, params):

        return response

+    def _handle_streaming_tool_calls(self, response, params):
+        """
+        Generator that handles streaming responses containing tool calls.
+        It accumulates tool call chunks, executes them, and recursively calls completion.
+        """
+        tool_calls_buffer = {}  # index -> {id, function: {name, arguments}}
+        accumulating_tools = False
+        logger.info("[LLM STREAM] Streaming tool calls")
+
+        for chunk in response:
+            delta = chunk.choices[0].delta
+
+            # If there is content, yield it immediately to the user
+            if delta.content:
+                yield chunk
+
+            # If there are tool calls, buffer them
+            if delta.tool_calls:
+                if not accumulating_tools:
+                    logger.info("[AI STREAM] Start: buffer function")
+                    accumulating_tools = True
+                for tc_chunk in delta.tool_calls:
+                    idx = tc_chunk.index
+
+                    if idx not in tool_calls_buffer:
+                        tool_calls_buffer[idx] = {
+                            "id": "",
+                            "type": "function",
+                            "function": {"name": "", "arguments": ""}
+                        }
+
+                    if tc_chunk.id:
+                        tool_calls_buffer[idx]["id"] += tc_chunk.id
+
+                    if tc_chunk.function:
+                        if tc_chunk.function.name:
+                            tool_calls_buffer[idx]["function"]["name"] += tc_chunk.function.name
+                        if tc_chunk.function.arguments:
+                            tool_calls_buffer[idx]["function"]["arguments"] += tc_chunk.function.arguments
+
+        # If we accumulated tool calls, reconstruct them and recurse
+        if accumulating_tools and tool_calls_buffer:
+
+            # Helper classes to mimic the object structure LiteLLM expects in _completion_with_tools
+            class FunctionMock:
+                def __init__(self, name, arguments):
+                    self.name = name
+                    self.arguments = arguments
+
+            class ToolCallMock:
+                def __init__(self, t_id, name, arguments):
+                    self.id = t_id
+                    self.function = FunctionMock(name, arguments)
+                    self.type = "function"
+
+            # Prepare list for the recursive call
+            reconstructed_tool_calls = []
+
+            # Prepare message to append to history (as dict for JSON serialization)
+            assistant_message_tool_calls = []
+
+            for idx in sorted(tool_calls_buffer.keys()):
+                data = tool_calls_buffer[idx]
+
+                # Create object for internal logic
+                tc_obj = ToolCallMock(
+                    t_id=data['id'],
+                    name=data['function']['name'],
+                    arguments=data['function']['arguments']
+                )
+                reconstructed_tool_calls.append(tc_obj)
+
+                # Create dict for history
+                assistant_message_tool_calls.append({
+                    "id": data['id'],
+                    "type": "function",
+                    "function": {
+                        "name": data['function']['name'],
+                        "arguments": data['function']['arguments']
+                    }
+                })
+
+            # Append the Assistant's intent to call tools to the history
+            params["messages"].append({
+                "role": "assistant",
+                "content": None,
+                "tool_calls": assistant_message_tool_calls
+            })
+
+            # Recursively call completion with the reconstructed tools
+            # yield from delegates the generation of the next stream (result of tool) to this generator
+            yield from self._completion_with_tools(reconstructed_tool_calls, params)

    def _responses_with_tools(self, tool_calls, params):
        """Handle tool calls recursively until no more tool calls are present."""
        for tool_call in tool_calls:
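For readers skimming the diff, here is a self-contained sketch of the accumulation pattern `_handle_streaming_tool_calls` implements: fragments sharing an `index` are concatenated, and the JSON `arguments` string only becomes parseable once the stream is exhausted. The helper name and the fragment dicts are illustrative, not part of the module:

```python
import json

def merge_tool_call_fragments(fragments):
    """Concatenate streamed tool-call fragments, grouping them by index."""
    buffer = {}
    for frag in fragments:
        entry = buffer.setdefault(frag["index"], {"id": "", "name": "", "arguments": ""})
        entry["id"] += frag.get("id") or ""
        entry["name"] += frag.get("name") or ""
        entry["arguments"] += frag.get("arguments") or ""
    return [buffer[idx] for idx in sorted(buffer)]

# The arguments fragments are not valid JSON on their own; only the
# concatenation parses, which is why buffering must finish first.
calls = merge_tool_call_fragments([
    {"index": 0, "id": "call_123", "name": "roll_dice", "arguments": ""},
    {"index": 0, "id": None, "name": None, "arguments": '{"sides":'},
    {"index": 0, "id": None, "name": None, "arguments": ' 6}'},
])
assert calls[0]["name"] == "roll_dice"
assert json.loads(calls[0]["arguments"]) == {"sides": 6}
```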
33 changes: 0 additions & 33 deletions backend/tests/test_litellm_base_processor.py
@@ -577,39 +577,6 @@ def test_non_string_provider_raises_error(mock_settings):  # pylint: disable=unused-argument
        LitellmProcessor(config=config, user_session=None)


-# ============================================================================
-# Streaming with Tools Tests
-# ============================================================================
-
-
-@patch.object(settings, "AI_EXTENSIONS", new_callable=lambda: {
-    "default": {
-        "MODEL": "openai/gpt-4",
-    }
-})
-@pytest.mark.django_db
-def test_streaming_with_tools_disables_streaming(mock_settings):  # pylint: disable=unused-argument
-    """
-    Test that streaming is disabled when tools are enabled.
-    """
-    config = {
-        "LitellmProcessor": {
-            "stream": True,
-            "enabled_tools": ["roll_dice"],
-        }
-    }
-    with patch('openedx_ai_extensions.processors.litellm_base_processor.logger') as mock_logger:
-        processor = LitellmProcessor(config=config, user_session=None)
-
-        # Verify streaming was disabled
-        assert processor.stream is False
-
-        # Verify warning was logged
-        mock_logger.warning.assert_called_once_with(
-            "Streaming responses with tools is not supported; disabling streaming."
-        )
-
-
# ============================================================================
# MCP Configs Tests
# ============================================================================
135 changes: 134 additions & 1 deletion backend/tests/test_llm_processor.py
@@ -9,6 +9,7 @@
from opaque_keys.edx.keys import CourseKey
from opaque_keys.edx.locator import BlockUsageLocator

+from openedx_ai_extensions.functions.decorators import AVAILABLE_TOOLS
from openedx_ai_extensions.processors.llm_processor import LLMProcessor
from openedx_ai_extensions.workflows.models import AIWorkflowProfile, AIWorkflowScope, AIWorkflowSession

@@ -89,8 +90,9 @@ def llm_processor(user_session, settings):  # pylint: disable=redefined-outer-name

class MockDelta:
    """Mock for the delta object in a streaming chunk."""
-    def __init__(self, content):
+    def __init__(self, content, tool_calls=None):
        self.content = content
+        self.tool_calls = tool_calls


class MockChoice:
@@ -130,6 +132,26 @@ def __init__(self, content, is_stream=True):
]


+class MockUsage:
+    """Mock for usage statistics."""
+    def __init__(self, total_tokens=10):
+        self.total_tokens = total_tokens
+
+
+class MockStreamChunk:
+    """Mock for a streaming chunk."""
+    def __init__(self, content, is_delta=True):
+        self.usage = MockUsage(total_tokens=5)
+        self.delta = None
+        self.choices = []
+
+        if is_delta:
+            mock_delta = MockDelta(content)
+            self.choices = [MockChoice(delta=mock_delta)]
+            self.delta = mock_delta
+        self.response = Mock(id="stream-id-123")
+
+
# ============================================================================
# Non-Streaming Tests (Standard)
# ============================================================================
@@ -713,3 +735,114 @@ def test_call_with_custom_prompt_missing_prompt_raises_error(

    with pytest.raises(ValueError, match="Custom prompt not provided in configuration"):
        processor.process(input_data="Test input")
+
+
+# ============================================================================
+# Streaming Tool Call Tests
+# ============================================================================
+
+class MockToolStreamChunk:
+    """
+    Helper for simulating tool call chunks in a stream.
+    Structure follows: chunk.choices[0].delta.tool_calls[...]
+    """
+
+    def __init__(self, index, tool_id=None, name=None, arguments=None):
+        self.usage = MockUsage(total_tokens=5)
+
+        # 1. Create the function mock
+        func_mock = Mock()
+        func_mock.name = name
+        func_mock.arguments = arguments
+
+        # 2. Create the tool_call mock
+        tool_call_mock = Mock()
+        tool_call_mock.index = index
+        tool_call_mock.id = tool_id
+        tool_call_mock.function = func_mock
+
+        # Construct the delta
+        delta = MockDelta(content=None, tool_calls=[tool_call_mock])
+
+        # Construct the choice
+        self.choices = [MockChoice(delta=delta)]
+
+
+@pytest.mark.django_db
+@patch("openedx_ai_extensions.processors.llm_processor.completion")
+@patch("openedx_ai_extensions.processors.llm_processor.adapt_to_provider")
+def test_streaming_tool_execution_recursion(
+    mock_adapt, mock_completion, llm_processor  # pylint: disable=redefined-outer-name
+):
+    """
+    Test that streaming correctly handles tool calls:
+    1. Buffers tool call chunks.
+    2. Executes the tool.
+    3. Recursively calls completion with tool output.
+    4. Yields the final content chunks.
+    """
+    # 1. Setup
+    mock_adapt.side_effect = lambda provider, params, **kwargs: params
+
+    # Configure processor for streaming + custom function calling
+    llm_processor.config["function"] = "summarize_content"  # Uses _call_completion_wrapper
+    llm_processor.config["stream"] = True
+    llm_processor.stream = True
+    llm_processor.extra_params["tools"] = ["mock_tool"]  # Needs to pass check in init if strict, but mainly for logic
+
+    # 2. Define a Mock Tool
+    mock_tool_func = Mock(return_value="tool_result_value")
+
+    # Patch the global AVAILABLE_TOOLS to include our mock
+    with patch.dict(AVAILABLE_TOOLS, {"mock_tool": mock_tool_func}):
+        # 3. Define Stream Sequences
+
+        # Sequence 1: The model decides to call "mock_tool" with args {"arg": "val"},
+        # split into multiple chunks to test buffering logic
+        tool_chunks = [
+            # Chunk 1: ID and Name
+            MockToolStreamChunk(index=0, tool_id="call_123", name="mock_tool"),
+            # Chunk 2: Start of args
+            MockToolStreamChunk(index=0, arguments='{"arg":'),
+            # Chunk 3: End of args
+            MockToolStreamChunk(index=0, arguments=' "val"}'),
+        ]
+
+        # Sequence 2: The model sees the tool result and generates final text
+        content_chunks = [
+            MockStreamChunk("Result "),
+            MockStreamChunk("is "),
+            MockStreamChunk("tool_result_value"),
+        ]
+
+        # Configure completion to return the first sequence, then the second
+        mock_completion.side_effect = [iter(tool_chunks), iter(content_chunks)]
+
+        # 4. Execute
+        generator = llm_processor.process(context="Ctx")
+        results = list(generator)
+
+        # 5. Assertions
+
+        # Check final output (byte encoded by _handle_streaming_completion)
+        assert b"Result " in results
+        assert b"is " in results
+        assert b"tool_result_value" in results
+
+        # Check Tool Execution
+        mock_tool_func.assert_called_once_with(arg="val")
+
+        # Check Recursion (completion called twice)
+        assert mock_completion.call_count == 2
+
+        # Verify second call included the tool output
+        second_call_kwargs = mock_completion.call_args_list[1][1]
+        messages = second_call_kwargs["messages"]
+
+        # Should have: System, (Context/User), Assistant(ToolCall), Tool(Result)
+        # Finding the tool message
+        tool_msg = next((m for m in messages if m.get("role") == "tool"), None)
+        assert tool_msg is not None
+        assert tool_msg["tool_call_id"] == "call_123"
+        assert tool_msg["content"] == "tool_result_value"
+        assert tool_msg["name"] == "mock_tool"