|
1 | 1 | #!/usr/bin/env python3 |
2 | 2 | """ |
3 | 3 | Test script to verify multi-function call fix. |
| 4 | +
|
| 5 | +This script tests that the Gemini integration can handle multiple function calls |
| 6 | +in a single query, which was the key enhancement made to enable complex queries |
| 7 | +like "find madrigals in Florence" that require multiple entity lookups. |
| 8 | +
|
| 9 | +Usage: |
| 10 | + cd shared && poetry run python -m nlq2sparql.examples.tracing.test_multi_function_fix |
4 | 11 | """ |
5 | 12 |
|
6 | 13 | import asyncio |
7 | | -import json |
8 | 14 | import sys |
| 15 | +import os |
9 | 16 | from pathlib import Path |
10 | 17 |
|
11 | | -# Add parent directory to path for imports |
12 | | -sys.path.append(str(Path(__file__).parent.parent.parent)) |
| 18 | +# Ensure we can import the nlq2sparql module |
| 19 | +project_root = Path(__file__).parent.parent.parent.parent |
| 20 | +if str(project_root / "shared") not in sys.path: |
| 21 | + sys.path.insert(0, str(project_root / "shared")) |
| 22 | + |
| 23 | +try: |
| 24 | + from nlq2sparql.integrations.gemini_integration import GeminiWikidataIntegration |
| 25 | +except ImportError as e: |
| 26 | + print(f"❌ Import error: {e}") |
| 27 | + print("Make sure you're running from the shared/ directory with:") |
| 28 | + print("poetry run python -m nlq2sparql.examples.tracing.test_multi_function_fix") |
| 29 | + sys.exit(1) |
13 | 30 |
|
14 | | -from integrations.gemini_integration import GeminiWikidataIntegration |
15 | | -from tracing import get_tracer, export_trace_log |
16 | 31 |
|
17 | 32 | async def test_multi_function_calls(): |
18 | 33 | """Test that multiple function calls are executed properly.""" |
19 | 34 | print("🧪 Testing Multi-Function Call Fix") |
20 | 35 | print("=" * 50) |
21 | 36 |
|
22 | | - # Initialize tracing |
23 | | - tracer = get_tracer("multi_function_test") |
24 | | - |
25 | 37 | try: |
26 | 38 | # Initialize the integration |
27 | 39 | integration = GeminiWikidataIntegration() |
28 | 40 |
|
29 | 41 | # Test query that should trigger 2 function calls |
30 | | - query = "Find all madrigals written in Florence. First look up madrigal and Florence in Wikidata, then write a comprehensive SPARQL query that finds musical works of type madrigal that were composed in or associated with Florence, Italy. Include titles and composers." |
| 42 | + query = ("Find all madrigals written in Florence. First look up madrigal and Florence in Wikidata, " |
| 43 | + "then write a comprehensive SPARQL query that finds musical works of type madrigal that were " |
| 44 | + "composed in or associated with Florence, Italy. Include titles and composers.") |
31 | 45 |
|
32 | 46 | print(f"🔍 Query: {query[:100]}...") |
33 | 47 | print(f"📊 Expected function calls: 2 (madrigal + Florence)") |
34 | 48 | print() |
35 | 49 |
|
36 | | - # Execute the query with tracing |
37 | | - with tracer.trace_operation("multi_function_test"): |
38 | | - response = await integration.send_message_with_tools(query) |
39 | | - |
40 | | - # Analyze results |
41 | | - function_calls = response.get('function_calls', []) |
42 | | - |
43 | | - print(f"✅ Total function calls executed: {len(function_calls)}") |
44 | | - print() |
| 50 | + # Execute the query |
| 51 | + response = await integration.send_message_with_tools(query) |
| 52 | + |
| 53 | + # Analyze results |
| 54 | + function_calls = response.get('function_calls', []) |
| 55 | + |
| 56 | + print(f"✅ Total function calls executed: {len(function_calls)}") |
| 57 | + print() |
| 58 | + |
| 59 | + for i, call in enumerate(function_calls, 1): |
| 60 | + entity = call['arguments'].get('entity_label', 'unknown') |
| 61 | + result = call['result'] |
| 62 | + print(f" {i}. {call['function']}(\"{entity}\") → {result}") |
| 63 | + |
| 64 | + print() |
| 65 | + if len(function_calls) >= 2: |
| 66 | + print("🎉 SUCCESS: Multiple function calls executed!") |
| 67 | + print("✅ Fix verified - both madrigal and Florence lookups completed") |
45 | 68 |
|
46 | | - for i, call in enumerate(function_calls, 1): |
47 | | - entity = call['arguments'].get('entity_label', 'unknown') |
48 | | - result = call['result'] |
49 | | - print(f" {i}. {call['function']}(\"{entity}\") → {result}") |
| 69 | + # Check if we got QIDs for both |
| 70 | + results = [call['result'] for call in function_calls] |
| 71 | + qids = [r for r in results if isinstance(r, str) and r.startswith('Q')] |
50 | 72 |
|
51 | | - print() |
52 | | - if len(function_calls) >= 2: |
53 | | - print("🎉 SUCCESS: Multiple function calls executed!") |
54 | | - print("✅ Fix verified - both madrigal and Florence lookups completed") |
55 | | - |
56 | | - # Check if we got QIDs for both |
57 | | - results = [call['result'] for call in function_calls] |
58 | | - qids = [r for r in results if isinstance(r, str) and r.startswith('Q')] |
59 | | - |
60 | | - if len(qids) >= 2: |
61 | | - print(f"🔗 Entity QIDs resolved: {qids}") |
62 | | - print("✅ Ready for SPARQL generation with both entities") |
63 | | - else: |
64 | | - print(f"⚠️ Some lookups may have failed: {results}") |
65 | | - |
| 73 | + if len(qids) >= 2: |
| 74 | + print(f"🔗 Entity QIDs resolved: {qids}") |
| 75 | + print("✅ Ready for SPARQL generation with both entities") |
66 | 76 | else: |
67 | | - print("❌ FAILURE: Still only executing single function call") |
68 | | - print("🔍 Need to investigate further...") |
69 | | - |
70 | | - print() |
71 | | - print(f"📝 Final response: {response['text'][:200]}...") |
72 | | - |
73 | | - # Save detailed trace |
74 | | - trace_file = export_trace_log("../../../logs/multi_function_test_trace.json") |
75 | | - print(f"💾 Detailed trace saved to: {trace_file}") |
76 | | - |
77 | | - print() |
78 | | - if len(function_calls) >= 2: |
79 | | - print("🎉 SUCCESS: Multiple function calls executed!") |
80 | | - print("✅ Fix verified - both madrigal and Florence lookups completed") |
| 77 | + print(f"⚠️ Some lookups may have failed: {results}") |
81 | 78 |
|
82 | | - # Check if we got QIDs for both |
83 | | - results = [call['result'] for call in function_calls] |
84 | | - qids = [r for r in results if isinstance(r, str) and r.startswith('Q')] |
85 | | - |
86 | | - if len(qids) >= 2: |
87 | | - print(f"🔗 Entity QIDs resolved: {qids}") |
88 | | - print("✅ Ready for SPARQL generation with both entities") |
89 | | - else: |
90 | | - print(f"⚠️ Some lookups may have failed: {results}") |
91 | | - |
92 | | - else: |
93 | | - print("❌ FAILURE: Still only executing single function call") |
94 | | - print("🔍 Need to investigate further...") |
95 | | - |
96 | | - print() |
97 | | - print(f"📝 Final response: {response['text'][:200]}...") |
98 | | - |
99 | | - # Save detailed trace |
100 | | - trace_data = tracer.export_trace() |
101 | | - trace_file = Path("../../../logs/multi_function_test_trace.json") |
102 | | - trace_file.write_text(json.dumps(trace_data, indent=2)) |
103 | | - print(f"💾 Detailed trace saved to: {trace_file}") |
| 79 | + else: |
| 80 | + print("❌ FAILURE: Still only executing single function call") |
| 81 | + print("🔍 Need to investigate further...") |
| 82 | + |
| 83 | + print() |
| 84 | + if response['text']: |
| 85 | + print(f"📝 Final response preview: {response['text'][:200]}...") |
| 86 | + else: |
| 87 | + print("📝 No text response (function calls only)") |
| 88 | + |
| 89 | + return len(function_calls) >= 2 |
104 | 90 |
|
105 | 91 | except Exception as e: |
106 | 92 | print(f"❌ Error during test: {e}") |
107 | 93 | import traceback |
108 | 94 | traceback.print_exc() |
| 95 | + return False |
| 96 | + |
109 | 97 |
|
110 | 98 | async def main(): |
111 | | - await test_multi_function_calls() |
| 99 | + """Main test function""" |
| 100 | + print("🔬 Multi-Function Call Test Suite") |
| 101 | + print("=" * 60) |
| 102 | + print() |
| 103 | + |
| 104 | + success = await test_multi_function_calls() |
| 105 | + |
| 106 | + print() |
| 107 | + print("� Test Results:") |
| 108 | + print("=" * 20) |
| 109 | + if success: |
| 110 | + print("✅ PASSED: Multi-function call processing works correctly") |
| 111 | + print("🎯 System can handle complex queries requiring multiple entity lookups") |
| 112 | + else: |
| 113 | + print("❌ FAILED: Multi-function call processing needs investigation") |
| 114 | + |
| 115 | + return success |
| 116 | + |
112 | 117 |
|
113 | 118 | if __name__ == "__main__": |
114 | | - asyncio.run(main()) |
| 119 | + result = asyncio.run(main()) |
| 120 | + sys.exit(0 if result else 1) |
0 commit comments