1212import pytest
1313import asyncio
1414from pathlib import Path
15- from typing import Dict , Any
16- from unittest .mock import Mock , AsyncMock , patch
15+ from typing import Dict , Any , List , Optional
1716
1817from orchestrator .compiler .yaml_compiler import YAMLCompiler , YAMLCompilerError
1918from orchestrator .control_systems .model_based_control_system import ModelBasedControlSystem
2019from orchestrator .models .model_registry import ModelRegistry
2120from orchestrator .core .pipeline import Pipeline
2221from orchestrator .core .task import Task , TaskStatus
22+ from orchestrator .core .model import Model , ModelCapabilities
2323
2424
2525class TestYAMLCompilation :
@@ -241,25 +241,87 @@ async def test_complex_template_expressions(self, compiler):
241241 assert list (pipeline .tasks .values ())[0 ].action .strip () == expected_action .strip ()
242242
243243
244+ class TestableDeclarativeModel (Model ):
245+ """A testable model for declarative framework tests."""
246+
247+ def __init__ (self , name = "test-model" , provider = "test" ):
248+ capabilities = ModelCapabilities (
249+ supported_tasks = ["generate" , "analyze" ],
250+ context_window = 8192 ,
251+ languages = ["en" ]
252+ )
253+ super ().__init__ (name = name , provider = provider , capabilities = capabilities )
254+ self .generate_calls : List [tuple ] = []
255+ self .generate_return_value = "Test result"
256+ self .generate_side_effect : Optional [Exception ] = None
257+ self .generate_delay : float = 0.0
258+ self .call_count = 0
259+
260+ async def generate (self , prompt , ** kwargs ):
261+ """Generate response."""
262+ self .generate_calls .append ((prompt , kwargs ))
263+ self .call_count += 1
264+
265+ # Simulate delay if configured
266+ if self .generate_delay > 0 :
267+ await asyncio .sleep (self .generate_delay )
268+
269+ # Raise exception if configured
270+ if self .generate_side_effect :
271+ raise self .generate_side_effect
272+
273+ # Return configured value or callable result
274+ if callable (self .generate_return_value ):
275+ return self .generate_return_value (self .call_count )
276+ return self .generate_return_value
277+
278+ async def generate_structured (self , prompt , schema , ** kwargs ):
279+ """Generate structured output."""
280+ result = await self .generate (prompt , ** kwargs )
281+ if isinstance (result , dict ):
282+ return result
283+ return {"result" : result }
284+
285+ async def validate_response (self , response , schema ):
286+ """Validate response."""
287+ return True
288+
289+ def estimate_tokens (self , text ):
290+ """Estimate tokens."""
291+ return len (text .split ())
292+
293+ def estimate_cost (self , input_tokens , output_tokens ):
294+ """Estimate cost."""
295+ return 0.001
296+
297+ async def health_check (self ):
298+ """Health check."""
299+ return True
300+
301+ def set_generate_return (self , value ):
302+ """Set the return value for generate."""
303+ self .generate_return_value = value
304+
305+ def set_generate_side_effect (self , exception ):
306+ """Set side effect (exception) for generate."""
307+ self .generate_side_effect = exception
308+
309+ def set_generate_delay (self , delay ):
310+ """Set delay for generate (in seconds)."""
311+ self .generate_delay = delay
312+
313+
244314class TestModelBasedControlSystem :
245315 """Test ModelBasedControlSystem execution."""
246316
247317 @pytest .fixture
248318 def mock_model (self ):
249- """Create a mock AI model."""
250- model = Mock ()
251- model .name = "test-model"
252- model .generate = AsyncMock (return_value = "Test result" )
253- model .capabilities = {
254- "tasks" : ["generate" , "analyze" ],
255- "context_window" : 8192 ,
256- "output_tokens" : 4096
257- }
258- return model
319+ """Create a testable AI model."""
320+ return TestableDeclarativeModel ()
259321
260322 @pytest .fixture
261323 def model_registry (self , mock_model ):
262- """Create a model registry with mock model."""
324+ """Create a model registry with test model."""
263325 registry = ModelRegistry ()
264326 registry .register_model (mock_model )
265327 return registry
@@ -312,7 +374,7 @@ async def test_context_propagation(self, control_system):
312374 async def test_error_handling (self , control_system , mock_model ):
313375 """Test error handling during task execution."""
314376 # Configure model to raise error
315- mock_model .generate . side_effect = Exception ("Model error" )
377+ mock_model .set_generate_side_effect ( Exception ("Model error" ) )
316378
317379 task = Task (
318380 id = "error_task" ,
@@ -334,16 +396,12 @@ async def test_error_handling(self, control_system, mock_model):
334396 async def test_retry_mechanism (self , control_system , mock_model ):
335397 """Test retry mechanism on failure."""
336398 # Configure model to fail twice then succeed
337- call_count = 0
338-
339- async def mock_generate (* args , ** kwargs ):
340- nonlocal call_count
341- call_count += 1
399+ def generate_with_retry (call_count ):
342400 if call_count < 3 :
343401 raise Exception ("Temporary failure" )
344402 return "Success after retries"
345403
346- mock_model .generate = mock_generate
404+ mock_model .set_generate_return ( generate_with_retry )
347405
348406 task = Task (
349407 id = "retry_task" ,
@@ -354,17 +412,14 @@ async def mock_generate(*args, **kwargs):
354412 result = await control_system .execute_task (task , {})
355413
356414 assert result == "Success after retries"
357- assert call_count == 3
415+ assert mock_model . call_count == 3
358416
359417 @pytest .mark .asyncio
360418 async def test_timeout_handling (self , control_system , mock_model ):
361419 """Test timeout handling."""
362420 # Configure model to take too long
363- async def slow_generate (* args , ** kwargs ):
364- await asyncio .sleep (5 )
365- return "Too late"
366-
367- mock_model .generate = slow_generate
421+ mock_model .set_generate_delay (5 ) # 5 second delay
422+ mock_model .set_generate_return ("Too late" )
368423
369424 task = Task (
370425 id = "timeout_task" ,
@@ -410,27 +465,26 @@ def setup(self):
410465 compiler = YAMLCompiler ()
411466 model_registry = ModelRegistry ()
412467
413- # Mock model that returns predictable results
414- model = Mock ()
415- model .name = "test-model"
416-
417- async def mock_generate (prompt , ** kwargs ):
418- if "research plan" in prompt .lower ():
419- return "Research plan: 1. Gather data 2. Analyze 3. Report"
420- elif "gather information" in prompt .lower ():
421- return "Information: Key facts about the topic"
422- elif "analyze" in prompt .lower ():
423- return "Analysis: Important insights discovered"
424- elif "report" in prompt .lower ():
425- return "Report: Comprehensive summary of findings"
426- return "Generic result"
427-
428- model .generate = mock_generate
429- model .capabilities = {
430- "tasks" : ["generate" , "analyze" ],
431- "context_window" : 8192
432- }
433-
468+ # Create testable model that returns predictable results
469+ class IntegrationTestModel (TestableDeclarativeModel ):
470+ async def generate (self , prompt , ** kwargs ):
471+ # Track the call
472+ self .generate_calls .append ((prompt , kwargs ))
473+ self .call_count += 1
474+
475+ # Return specific responses based on prompt
476+ prompt_lower = prompt .lower ()
477+ if "research plan" in prompt_lower :
478+ return "Research plan: 1. Gather data 2. Analyze 3. Report"
479+ elif "gather information" in prompt_lower :
480+ return "Information: Key facts about the topic"
481+ elif "analyze" in prompt_lower :
482+ return "Analysis: Important insights discovered"
483+ elif "report" in prompt_lower :
484+ return "Report: Comprehensive summary of findings"
485+ return "Generic result"
486+
487+ model = IntegrationTestModel ()
434488 model_registry .register_model (model )
435489 control_system = ModelBasedControlSystem (model_registry )
436490
@@ -525,16 +579,21 @@ async def test_error_recovery_pipeline(self, setup):
525579 result: "{{process_data.result}}"
526580"""
527581
528- # Mock the model to fail for primary source
582+ # Configure the model to fail for primary source
529583 model = list (control_system .model_registry ._models .values ())[0 ]
584+
585+ # Store original generate method
530586 original_generate = model .generate
531587
532- async def mock_generate_with_error (prompt , ** kwargs ):
588+ # Create a new bound method that includes error handling
589+ async def generate_with_error (prompt , ** kwargs ):
533590 if "primary source" in prompt :
534591 raise Exception ("Primary source unavailable" )
592+ # Call the original method bound to the model instance
535593 return await original_generate (prompt , ** kwargs )
536594
537- model .generate = mock_generate_with_error
595+ # Replace the generate method
596+ model .generate = generate_with_error
538597
539598 pipeline = await compiler .compile (yaml_content , {})
540599 results = await control_system .execute_pipeline (pipeline )
0 commit comments