44from pathlib import Path
55from typing import Dict , Any , Optional
66import yaml
7- from unittest .mock import Mock , patch , AsyncMock
87
98from orchestrator import Orchestrator
9+ from orchestrator .models .model_registry import ModelRegistry
10+ from orchestrator .tools .registry import ToolRegistry
1011
1112
1213class BaseExampleTest :
@@ -37,43 +38,71 @@ def load_yaml_pipeline(self, pipeline_name: str) -> Dict[str, Any]:
3738
3839 @pytest .fixture
3940 def mock_model_registry (self ):
40- """Mock model registry for tests."""
41- with patch ('orchestrator.models.registry.ModelRegistry' ) as mock :
42- registry = Mock ()
43-
44- # Mock model resolution
45- async def mock_resolve (model_spec ):
41+ """Create test model registry."""
42+ # Create a test registry
43+ class TestModelRegistry (ModelRegistry ):
44+ def __init__ (self ):
45+ super ().__init__ ()
46+ self .resolve_calls = []
47+
48+ async def resolve_model (self , model_spec ):
49+ """Test model resolution."""
50+ self .resolve_calls .append (model_spec )
4651 return {
47- "provider" : "mock " ,
48- "model" : "mock -model" ,
52+ "provider" : "test " ,
53+ "model" : "test -model" ,
4954 "temperature" : 0.7
5055 }
51-
52- registry .resolve_model = AsyncMock (side_effect = mock_resolve )
53- mock .return_value = registry
54- yield registry
56+
57+ # Store original registry class
58+ import orchestrator .models .registry
59+ original_registry = orchestrator .models .registry .ModelRegistry
60+
61+ # Replace with test registry
62+ test_registry = TestModelRegistry ()
63+ orchestrator .models .registry .ModelRegistry = lambda : test_registry
64+
65+ yield test_registry
66+
67+ # Restore original registry
68+ orchestrator .models .registry .ModelRegistry = original_registry
5569
5670 @pytest .fixture
5771 def mock_tool_registry (self ):
58- """Mock tool registry for tests."""
59- with patch ('orchestrator.tools.registry.ToolRegistry' ) as mock :
60- registry = Mock ()
61-
62- # Mock tool discovery
63- async def mock_discover (action_desc ):
72+ """Create test tool registry."""
73+ # Create a test registry
74+ class TestToolRegistry (ToolRegistry ):
75+ def __init__ (self ):
76+ super ().__init__ ()
77+ self .discover_calls = []
78+
79+ async def discover_tool (self , action_desc ):
80+ """Test tool discovery."""
81+ self .discover_calls .append (action_desc )
82+
6483 # Simple tool mapping based on keywords
65- if "search" in action_desc .lower () or "web" in action_desc .lower ():
84+ action_lower = action_desc .lower ()
85+ if "search" in action_lower or "web" in action_lower :
6686 return {"tool" : "web_search" , "params" : {}}
67- elif "analyze" in action_desc . lower () :
87+ elif "analyze" in action_lower :
6888 return {"tool" : "analyzer" , "params" : {}}
69- elif "generate" in action_desc . lower () :
89+ elif "generate" in action_lower :
7090 return {"tool" : "generator" , "params" : {}}
7191 else :
7292 return {"tool" : "generic" , "params" : {}}
73-
74- registry .discover_tool = AsyncMock (side_effect = mock_discover )
75- mock .return_value = registry
76- yield registry
93+
94+ # Store original registry class
95+ import orchestrator .tools .registry
96+ original_registry = orchestrator .tools .registry .ToolRegistry
97+
98+ # Replace with test registry
99+ test_registry = TestToolRegistry ()
100+ orchestrator .tools .registry .ToolRegistry = lambda : test_registry
101+
102+ yield test_registry
103+
104+ # Restore original registry
105+ orchestrator .tools .registry .ToolRegistry = original_registry
77106
78107 async def run_pipeline_test (
79108 self ,
@@ -87,36 +116,48 @@ async def run_pipeline_test(
87116 # Load pipeline
88117 pipeline_config = self .load_yaml_pipeline (pipeline_name )
89118
90- # Mock tool executions if provided
119+ # Configure test responses if provided
91120 if mock_responses :
92- for step_id , response in mock_responses .items ():
93- # Mock the step execution
94- with patch .object (
95- orchestrator ,
96- '_execute_step' ,
97- new_callable = AsyncMock
98- ) as mock_exec :
99- mock_exec .return_value = response
100-
101- # Run pipeline
102- result = await orchestrator .execute_yaml (
103- yaml .dump (pipeline_config ),
104- context = inputs
105- )
106-
107- # Validate outputs if expected
108- if expected_outputs :
109- for key , expected_value in expected_outputs .items ():
110- assert key in result ['outputs' ]
111- if isinstance (expected_value , dict ):
112- # For complex objects, check structure
113- assert isinstance (result ['outputs' ][key ], dict )
114- for sub_key in expected_value :
115- assert sub_key in result ['outputs' ][key ]
116- else :
117- assert result ['outputs' ][key ] == expected_value
121+ # Store original execute_step method
122+ original_execute_step = orchestrator ._execute_step if hasattr (orchestrator , '_execute_step' ) else None
123+
124+ # Create a test execute_step that returns configured responses
125+ async def test_execute_step (step , context ):
126+ step_id = step .id if hasattr (step , 'id' ) else step .get ('id' )
127+ if step_id in mock_responses :
128+ return mock_responses [step_id ]
129+ # Fall back to original if available
130+ if original_execute_step :
131+ return await original_execute_step (step , context )
132+ return {}
133+
134+ # Replace the method
135+ orchestrator ._execute_step = test_execute_step
118136
119- return result
137+ try :
138+ # Run pipeline
139+ result = await orchestrator .execute_yaml (
140+ yaml .dump (pipeline_config ),
141+ context = inputs
142+ )
143+
144+ # Validate outputs if expected
145+ if expected_outputs :
146+ for key , expected_value in expected_outputs .items ():
147+ assert key in result ['outputs' ]
148+ if isinstance (expected_value , dict ):
149+ # For complex objects, check structure
150+ assert isinstance (result ['outputs' ][key ], dict )
151+ for sub_key in expected_value :
152+ assert sub_key in result ['outputs' ][key ]
153+ else :
154+ assert result ['outputs' ][key ] == expected_value
155+
156+ return result
157+ finally :
158+ # Restore original method if we replaced it
159+ if mock_responses and original_execute_step :
160+ orchestrator ._execute_step = original_execute_step
120161
121162 def validate_pipeline_structure (self , pipeline_name : str ):
122163 """Validate basic pipeline structure."""
0 commit comments