Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions backend/openedx_ai_extensions/workflows/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from django.utils.functional import cached_property
from opaque_keys.edx.django.models import CourseKeyField, UsageKeyField

from openedx_ai_extensions.workflows.orchestrators import BaseOrchestrator
from openedx_ai_extensions.workflows.template_utils import (
get_effective_config,
parse_json5_string,
Expand Down Expand Up @@ -279,17 +280,17 @@ def execute(self, user_input, action, user, running_context) -> dict[str, str |

try:
# Load the orchestrator for this workflow
from openedx_ai_extensions.workflows import orchestrators # pylint: disable=import-outside-toplevel

orchestrator_name = self.profile.orchestrator_class # "DirectLLMResponse"
orchestrator_class = getattr(orchestrators, orchestrator_name)
orchestrator = orchestrator_class(workflow=self, user=user, context=running_context)
orchestrator = BaseOrchestrator.get_orchestrator(
workflow=self,
user=user,
context=running_context,
)

self.action = action

if not hasattr(orchestrator, action):
raise NotImplementedError(
f"Orchestrator '{orchestrator_name}' does not implement action '{action}'"
f"Orchestrator '{self.profile.orchestrator_class}' does not implement action '{action}'"
)
result = getattr(orchestrator, action)(user_input)

Expand Down
67 changes: 56 additions & 11 deletions backend/openedx_ai_extensions/workflows/orchestrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,20 @@ def _execute_orchestrator_async(task_self, session_id, action, params=None):
'location_id': str(session.location_id),
}

# 3. Get orchestrator class and instantiate from current module's globals
# 3. Resolve and instantiate orchestrator via centralized factory
orchestrator_name = session.profile.orchestrator_class
orchestrator_class = sys.modules[__name__].__dict__.get(orchestrator_name)
if not orchestrator_class:
error_msg = f"Orchestrator class '{orchestrator_name}' not found in module"
logger.error(f"Task {task_id}: {error_msg}")
raise AttributeError(error_msg)
orchestrator = orchestrator_class(
workflow=session.scope,
user=session.user,
context=context
)
try:
orchestrator = BaseOrchestrator.get_orchestrator(
workflow=session.scope,
user=session.user,
context=context,
)
except (AttributeError, TypeError) as exc:
logger.error(
f"Task {task_id}: Failed to resolve orchestrator: {exc}",
exc_info=True,
)
raise

# 4. Validate action exists
if not hasattr(orchestrator, action):
Expand Down Expand Up @@ -146,6 +148,49 @@ def _emit_workflow_event(self, event_name: str) -> None:
def run(self, input_data):
raise NotImplementedError("Subclasses must implement run method")

@classmethod
def get_orchestrator(cls, *, workflow, user, context):
"""
Resolve and instantiate an orchestrator for the given workflow.

This factory method centralizes orchestrator lookup and validation.
It ensures that the resolved class exists and is a subclass of
BaseOrchestrator, providing a single, consistent entry point
for orchestrator creation across the codebase.

Args:
workflow: AIWorkflowScope instance that defines the workflow configuration.
user: User for whom the workflow is being executed.
context: Dictionary with runtime context (e.g. course_id, location_id).

Returns:
BaseOrchestrator: An instantiated orchestrator for the given workflow.

Raises:
AttributeError: If the configured orchestrator class cannot be found.
TypeError: If the resolved class is not a subclass of BaseOrchestrator.
"""
orchestrator_name = workflow.profile.orchestrator_class

try:
module = sys.modules[__name__]
orchestrator_class = getattr(module, orchestrator_name)
except AttributeError as exc:
raise AttributeError(
f"Orchestrator class '{orchestrator_name}' not found"
) from exc

if not issubclass(orchestrator_class, BaseOrchestrator):
raise TypeError(
f"{orchestrator_name} is not a subclass of BaseOrchestrator"
)

return orchestrator_class(
workflow=workflow,
user=user,
context=context,
)


class MockResponse(BaseOrchestrator):
"""
Expand Down
171 changes: 171 additions & 0 deletions backend/tests/test_base_orchestrator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
"""
Tests for the BaseOrchestrator class in openedx-ai-extensions workflows module.
"""

from unittest.mock import patch

import pytest
from django.contrib.auth import get_user_model

from openedx_ai_extensions.workflows.orchestrators import BaseOrchestrator

User = get_user_model()


# ============================================================================
# Fixtures
# ============================================================================

@pytest.fixture
def mock_user(db): # pylint: disable=unused-argument
"""
Create a test user.
"""
return User.objects.create_user(
username="testuser2", email="test2@example.com", password="password123"
)


@pytest.fixture
def mock_workflow_profile():
"""
Create a fake workflow profile object with orchestrator_class attribute.
"""
class Profile:
slug = "mock-profile"
orchestrator_class = "MockOrchestrator"

return Profile()


@pytest.fixture
def mock_workflow(mock_workflow_profile): # pylint: disable=redefined-outer-name
"""
Create a fake workflow object with profile and action attributes.
"""
class Workflow:
id = 123
profile = mock_workflow_profile
action = "test_action"

return Workflow()


# ============================================================================
# BaseOrchestrator Initialization Tests
# ============================================================================

@pytest.mark.django_db
def test_base_orchestrator_init(mock_workflow, mock_user): # pylint: disable=redefined-outer-name
"""
Test that BaseOrchestrator initializes attributes correctly.
"""
context = {"location_id": "loc-1", "course_id": "course-1"}
orchestrator = BaseOrchestrator(workflow=mock_workflow, user=mock_user, context=context)

assert orchestrator.workflow == mock_workflow
assert orchestrator.user == mock_user
assert orchestrator.profile == mock_workflow.profile
assert orchestrator.location_id == "loc-1"
assert orchestrator.course_id == "course-1"


# ============================================================================
# _emit_workflow_event Tests
# ============================================================================

@pytest.mark.django_db
@patch("openedx_ai_extensions.workflows.orchestrators.tracker")
def test_emit_workflow_event(mock_tracker, mock_workflow, mock_user): # pylint: disable=redefined-outer-name
"""
Test that _emit_workflow_event calls tracker.emit with correct payload.
"""
context = {"location_id": "loc-1", "course_id": "course-1"}
orchestrator = BaseOrchestrator(workflow=mock_workflow, user=mock_user, context=context)

orchestrator._emit_workflow_event("TEST_EVENT") # pylint: disable=protected-access

mock_tracker.emit.assert_called_once_with("TEST_EVENT", {
"workflow_id": str(mock_workflow.id),
"action": mock_workflow.action,
"course_id": str("course-1"),
"profile_name": mock_workflow.profile.slug,
"location_id": str("loc-1"),
})


# ============================================================================
# run Method Tests
# ============================================================================

@pytest.mark.django_db
def test_base_orchestrator_run_raises_not_implemented(mock_workflow, mock_user): # pylint: disable=redefined-outer-name
"""
Test that calling run on BaseOrchestrator raises NotImplementedError.
"""
orchestrator = BaseOrchestrator(workflow=mock_workflow, user=mock_user, context={})
with pytest.raises(NotImplementedError):
orchestrator.run({})


# ============================================================================
# get_orchestrator Classmethod Tests
# ============================================================================

@pytest.mark.django_db
def test_get_orchestrator_success(monkeypatch, mock_workflow, mock_user): # pylint: disable=redefined-outer-name
"""
Test get_orchestrator returns an instance of the resolved class.
"""
from openedx_ai_extensions.workflows import orchestrators # pylint: disable=import-outside-toplevel

class MockOrchestrator(BaseOrchestrator):
def run(self, input_data):
return {"status": "ok"}

monkeypatch.setitem(orchestrators.__dict__, "MockOrchestrator", MockOrchestrator)

context = {"location_id": "loc-1", "course_id": "course-1"}
orchestrator = BaseOrchestrator.get_orchestrator(
workflow=mock_workflow,
user=mock_user,
context=context
)

assert isinstance(orchestrator, MockOrchestrator)
assert orchestrator.workflow == mock_workflow
assert orchestrator.user == mock_user


@pytest.mark.django_db
def test_get_orchestrator_attribute_error(mock_workflow, mock_user): # pylint: disable=redefined-outer-name
"""
Test get_orchestrator raises AttributeError when class does not exist.
"""
mock_workflow.profile.orchestrator_class = "NonExistingClass"
context = {"location_id": None, "course_id": None}

with pytest.raises(AttributeError) as exc_info:
BaseOrchestrator.get_orchestrator(workflow=mock_workflow, user=mock_user, context=context)

assert "NonExistingClass" in str(exc_info.value)


@pytest.mark.django_db
def test_get_orchestrator_type_error(monkeypatch, mock_workflow, mock_user): # pylint: disable=redefined-outer-name
"""
Test get_orchestrator raises TypeError when resolved class is not a subclass of BaseOrchestrator.
"""
from openedx_ai_extensions.workflows import orchestrators # pylint: disable=import-outside-toplevel

class NotAnOrchestrator:
pass

monkeypatch.setitem(orchestrators.__dict__, "MockOrchestrator", NotAnOrchestrator)

context = {"location_id": None, "course_id": None}

with pytest.raises(TypeError) as exc_info:
BaseOrchestrator.get_orchestrator(workflow=mock_workflow, user=mock_user, context=context)

assert "MockOrchestrator is not a subclass of BaseOrchestrator" in str(exc_info.value)