Skip to content

Commit 6943513

Browse files
committed
new tests
1 parent c7aa925 commit 6943513

File tree

1 file changed

+200
-27
lines changed

1 file changed

+200
-27
lines changed

lib/crewai/tests/test_async_human_feedback.py

Lines changed: 200 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -274,18 +274,6 @@ def test_provider_initialization(self) -> None:
274274
quiet_provider = ConsoleProvider(verbose=False)
275275
assert quiet_provider.verbose is False
276276

277-
def test_provider_is_instance_of_protocol(self) -> None:
278-
"""Test that ConsoleProvider implements the protocol."""
279-
provider = ConsoleProvider()
280-
assert isinstance(provider, HumanFeedbackProvider)
281-
282-
def test_provider_has_verbose_attribute(self) -> None:
283-
"""Test that provider has verbose attribute."""
284-
provider = ConsoleProvider(verbose=True)
285-
assert provider.verbose is True
286-
287-
provider2 = ConsoleProvider(verbose=False)
288-
assert provider2.verbose is False
289277

290278

291279
# =============================================================================
@@ -539,6 +527,81 @@ def begin(self):
539527
with pytest.raises(ValueError, match="No pending feedback context"):
540528
flow.resume("some feedback")
541529

530+
def test_resume_from_async_context_raises_error(self) -> None:
531+
"""Test that resume() raises RuntimeError when called from async context."""
532+
import asyncio
533+
534+
class TestFlow(Flow):
535+
@start()
536+
def begin(self):
537+
return "started"
538+
539+
async def call_resume_from_async():
540+
with tempfile.TemporaryDirectory() as tmpdir:
541+
db_path = os.path.join(tmpdir, "test.db")
542+
persistence = SQLiteFlowPersistence(db_path)
543+
544+
# Save pending feedback
545+
context = PendingFeedbackContext(
546+
flow_id="async-context-test",
547+
flow_class="TestFlow",
548+
method_name="begin",
549+
method_output="output",
550+
message="Review:",
551+
)
552+
persistence.save_pending_feedback(
553+
flow_uuid="async-context-test",
554+
context=context,
555+
state_data={"id": "async-context-test"},
556+
)
557+
558+
flow = TestFlow.from_pending("async-context-test", persistence)
559+
560+
# This should raise RuntimeError because we're in an async context
561+
with pytest.raises(RuntimeError, match="cannot be called from within an async context"):
562+
flow.resume("feedback")
563+
564+
asyncio.run(call_resume_from_async())
565+
566+
@pytest.mark.asyncio
567+
async def test_resume_async_direct(self) -> None:
568+
"""Test resume_async() can be called directly in async context."""
569+
with tempfile.TemporaryDirectory() as tmpdir:
570+
db_path = os.path.join(tmpdir, "test.db")
571+
persistence = SQLiteFlowPersistence(db_path)
572+
573+
class TestFlow(Flow):
574+
@start()
575+
@human_feedback(message="Review:")
576+
def generate(self):
577+
return "content"
578+
579+
@listen(generate)
580+
def process(self, result):
581+
return f"processed: {result.feedback}"
582+
583+
# Save pending feedback
584+
context = PendingFeedbackContext(
585+
flow_id="async-direct-test",
586+
flow_class="TestFlow",
587+
method_name="generate",
588+
method_output="content",
589+
message="Review:",
590+
)
591+
persistence.save_pending_feedback(
592+
flow_uuid="async-direct-test",
593+
context=context,
594+
state_data={"id": "async-direct-test"},
595+
)
596+
597+
flow = TestFlow.from_pending("async-direct-test", persistence)
598+
599+
with patch("crewai.flow.flow.crewai_event_bus.emit"):
600+
result = await flow.resume_async("async feedback")
601+
602+
assert flow.last_human_feedback is not None
603+
assert flow.last_human_feedback.feedback == "async feedback"
604+
542605
@patch("crewai.flow.flow.crewai_event_bus.emit")
543606
def test_resume_basic(self, mock_emit: MagicMock) -> None:
544607
"""Test basic resume functionality."""
@@ -783,6 +846,131 @@ def process(self, feedback_result):
783846
# =============================================================================
784847

785848

849+
class TestAutoPersistence:
850+
"""Tests for automatic persistence when no persistence is provided."""
851+
852+
@patch("crewai.flow.flow.crewai_event_bus.emit")
853+
def test_auto_persistence_when_none_provided(self, mock_emit: MagicMock) -> None:
854+
"""Test that persistence is auto-created when HumanFeedbackPending is raised."""
855+
856+
class PausingProvider:
857+
def request_feedback(
858+
self, context: PendingFeedbackContext, flow: Flow
859+
) -> str:
860+
raise HumanFeedbackPending(
861+
context=context,
862+
callback_info={"paused": True},
863+
)
864+
865+
class TestFlow(Flow):
866+
@start()
867+
@human_feedback(
868+
message="Review:",
869+
provider=PausingProvider(),
870+
)
871+
def generate(self):
872+
return "content"
873+
874+
# Create flow WITHOUT persistence
875+
flow = TestFlow()
876+
assert flow._persistence is None # No persistence initially
877+
878+
# kickoff should auto-create persistence when HumanFeedbackPending is raised
879+
result = flow.kickoff()
880+
881+
# Should return HumanFeedbackPending (not raise it)
882+
assert isinstance(result, HumanFeedbackPending)
883+
884+
# Persistence should have been auto-created
885+
assert flow._persistence is not None
886+
887+
# The pending feedback should be saved
888+
flow_id = result.context.flow_id
889+
loaded = flow._persistence.load_pending_feedback(flow_id)
890+
assert loaded is not None
891+
892+
893+
class TestCollapseToOutcomeJsonParsing:
894+
"""Tests for _collapse_to_outcome JSON parsing edge cases."""
895+
896+
def test_json_string_response_is_parsed(self) -> None:
897+
"""Test that JSON string response from LLM is correctly parsed."""
898+
flow = Flow()
899+
900+
with patch("crewai.llm.LLM") as MockLLM:
901+
mock_llm = MagicMock()
902+
# Simulate LLM returning JSON string (the bug we fixed)
903+
mock_llm.call.return_value = '{"outcome": "approved"}'
904+
MockLLM.return_value = mock_llm
905+
906+
result = flow._collapse_to_outcome(
907+
feedback="I approve this",
908+
outcomes=["approved", "rejected"],
909+
llm="gpt-4o-mini",
910+
)
911+
912+
assert result == "approved"
913+
914+
def test_plain_string_response_is_matched(self) -> None:
915+
"""Test that plain string response is correctly matched."""
916+
flow = Flow()
917+
918+
with patch("crewai.llm.LLM") as MockLLM:
919+
mock_llm = MagicMock()
920+
# Simulate LLM returning plain outcome string
921+
mock_llm.call.return_value = "rejected"
922+
MockLLM.return_value = mock_llm
923+
924+
result = flow._collapse_to_outcome(
925+
feedback="This is not good",
926+
outcomes=["approved", "rejected"],
927+
llm="gpt-4o-mini",
928+
)
929+
930+
assert result == "rejected"
931+
932+
def test_invalid_json_falls_back_to_matching(self) -> None:
933+
"""Test that invalid JSON falls back to string matching."""
934+
flow = Flow()
935+
936+
with patch("crewai.llm.LLM") as MockLLM:
937+
mock_llm = MagicMock()
938+
# Invalid JSON that contains "approved"
939+
mock_llm.call.return_value = "{invalid json but says approved"
940+
MockLLM.return_value = mock_llm
941+
942+
result = flow._collapse_to_outcome(
943+
feedback="looks good",
944+
outcomes=["approved", "rejected"],
945+
llm="gpt-4o-mini",
946+
)
947+
948+
assert result == "approved"
949+
950+
def test_llm_exception_falls_back_to_simple_prompting(self) -> None:
951+
"""Test that LLM exception triggers fallback to simple prompting."""
952+
flow = Flow()
953+
954+
with patch("crewai.llm.LLM") as MockLLM:
955+
mock_llm = MagicMock()
956+
# First call raises, second call succeeds (fallback)
957+
mock_llm.call.side_effect = [
958+
Exception("Structured output failed"),
959+
"approved",
960+
]
961+
MockLLM.return_value = mock_llm
962+
963+
result = flow._collapse_to_outcome(
964+
feedback="I approve",
965+
outcomes=["approved", "rejected"],
966+
llm="gpt-4o-mini",
967+
)
968+
969+
assert result == "approved"
970+
# Verify it was called twice (initial + fallback)
971+
assert mock_llm.call.call_count == 2
972+
973+
786974
class TestAsyncHumanFeedbackEdgeCases:
787975
"""Edge case tests for async human feedback."""
788976

@@ -879,18 +1067,3 @@ def step(self):
8791067

8801068
assert flow.last_human_feedback.outcome == "approved"
8811069
assert flow.last_human_feedback.feedback == ""
882-
883-
def test_provider_is_protocol_not_base_class(self) -> None:
884-
"""Test that provider uses Protocol, not inheritance."""
885-
# This should work because Protocol uses structural typing
886-
887-
class CustomProvider:
888-
"""A provider that doesn't explicitly inherit from HumanFeedbackProvider."""
889-
890-
def request_feedback(
891-
self, context: PendingFeedbackContext, flow: Flow
892-
) -> str:
893-
return "feedback from custom provider"
894-
895-
provider = CustomProvider()
896-
assert isinstance(provider, HumanFeedbackProvider)

0 commit comments

Comments
 (0)