1
+ """
2
+ Unit tests for CoroutineChecker utility class.
3
+
4
+ Focused test suite covering core functionality and main edge cases.
5
+ """
6
+
7
+ import pytest
8
+ from unittest .mock import patch
9
+
10
+ from litellm .litellm_core_utils .coroutine_checker import CoroutineChecker , coroutine_checker
11
+
12
+
13
+ class TestCoroutineChecker :
14
+ """Test cases for CoroutineChecker class."""
15
+
16
+ def setup_method (self ):
17
+ """Set up test fixtures before each test method."""
18
+ self .checker = CoroutineChecker ()
19
+
20
+ def test_init (self ):
21
+ """Test CoroutineChecker initialization."""
22
+ checker = CoroutineChecker ()
23
+ assert isinstance (checker , CoroutineChecker )
24
+
25
+ @pytest .mark .parametrize ("obj,expected,description" , [
26
+ # Basic function types
27
+ (lambda : "sync" , False , "sync lambda" ),
28
+ (len , False , "built-in function" ),
29
+ # Non-callable objects
30
+ ("string" , False , "string" ),
31
+ (123 , False , "integer" ),
32
+ ([], False , "list" ),
33
+ ({}, False , "dict" ),
34
+ (None , False , "None" ),
35
+ ])
36
+ def test_is_async_callable_basic_and_non_callable (self , obj , expected , description ):
37
+ """Test is_async_callable with basic types and non-callable objects."""
38
+ assert self .checker .is_async_callable (obj ) is expected , f"Failed for { description } : { obj } "
39
+
40
+ def test_is_async_callable_async_and_sync_callables (self ):
41
+ """Test is_async_callable with various async and sync callable types."""
42
+ # Async and sync functions
43
+ async def async_func ():
44
+ return "async"
45
+
46
+ def sync_func ():
47
+ return "sync"
48
+
49
+ # Class methods
50
+ class TestClass :
51
+ def sync_method (self ):
52
+ return "sync"
53
+
54
+ async def async_method (self ):
55
+ return "async"
56
+
57
+ obj = TestClass ()
58
+
59
+ # Callable objects
60
+ class SyncCallable :
61
+ def __call__ (self ):
62
+ return "sync"
63
+
64
+ class AsyncCallable :
65
+ async def __call__ (self ):
66
+ return "async"
67
+
68
+ # Test all async callables
69
+ assert self .checker .is_async_callable (async_func ) is True
70
+ assert self .checker .is_async_callable (obj .async_method ) is True
71
+ assert self .checker .is_async_callable (AsyncCallable ()) is True
72
+
73
+ # Test all sync callables
74
+ assert self .checker .is_async_callable (sync_func ) is False
75
+ assert self .checker .is_async_callable (obj .sync_method ) is False
76
+ assert self .checker .is_async_callable (SyncCallable ()) is False
77
+
78
+ def test_is_async_callable_caching (self ):
79
+ """Test that is_async_callable caches callable objects."""
80
+ async def async_func ():
81
+ return "async"
82
+
83
+ # Test that it works correctly
84
+ result1 = self .checker .is_async_callable (async_func )
85
+ assert result1 is True
86
+
87
+ # Test that callable objects are cached
88
+ assert async_func in self .checker ._cache
89
+ assert self .checker ._cache [async_func ] is True
90
+
91
+ # Test that it works consistently
92
+ result2 = self .checker .is_async_callable (async_func )
93
+ assert result2 is True
94
+
95
+ def test_edge_cases_and_error_handling (self ):
96
+ """Test edge cases and error handling."""
97
+ from functools import partial
98
+
99
+ # Error handling cases
100
+ class ProblematicCallable :
101
+ def __getattr__ (self , name ):
102
+ if name == "__call__" :
103
+ raise Exception ("Cannot access __call__" )
104
+ raise AttributeError (f"'{ self .__class__ .__name__ } ' object has no attribute '{ name } '" )
105
+
106
+ class UnstringableCallable :
107
+ def __str__ (self ):
108
+ raise Exception ("Cannot convert to string" )
109
+
110
+ async def __call__ (self ):
111
+ return "async"
112
+
113
+ # Generator functions
114
+ def sync_generator ():
115
+ yield "sync"
116
+
117
+ async def async_generator ():
118
+ yield "async"
119
+
120
+ # Partial functions
121
+ def sync_func (x , y ):
122
+ return x + y
123
+
124
+ async def async_func (x , y ):
125
+ return x + y
126
+
127
+ sync_partial = partial (sync_func , 1 )
128
+ async_partial = partial (async_func , 1 )
129
+
130
+ # Test error handling
131
+ assert self .checker .is_async_callable (ProblematicCallable ()) is False
132
+ assert self .checker .is_async_callable (UnstringableCallable ()) is True
133
+
134
+ # Test generators (both sync and async generators are not coroutine functions)
135
+ assert self .checker .is_async_callable (sync_generator ) is False
136
+ assert self .checker .is_async_callable (async_generator ) is False
137
+
138
+ # Test partial functions (don't preserve coroutine nature)
139
+ assert self .checker .is_async_callable (sync_partial ) is False
140
+ assert self .checker .is_async_callable (async_partial ) is False
141
+
142
+ def test_error_handling_in_inspect (self ):
143
+ """Test error handling when inspect.iscoroutinefunction raises exception."""
144
+ with patch ('inspect.iscoroutinefunction' , side_effect = Exception ("Inspect error" )):
145
+ async def async_func ():
146
+ return "async"
147
+
148
+ # Should return False when inspect raises exception
149
+ assert self .checker .is_async_callable (async_func ) is False
150
+
151
+ def test_global_coroutine_checker_instance (self ):
152
+ """Test the global coroutine_checker instance."""
153
+ assert isinstance (coroutine_checker , CoroutineChecker )
154
+
155
+ async def async_func ():
156
+ return "async"
157
+
158
+ def sync_func ():
159
+ return "sync"
160
+
161
+ assert coroutine_checker .is_async_callable (async_func ) is True
162
+ assert coroutine_checker .is_async_callable (sync_func ) is False
0 commit comments