Skip to content

Commit f6e44a8

Browse files
committed
Add tests
1 parent 576e4a5 commit f6e44a8

File tree

1 file changed

+315
-0
lines changed

1 file changed

+315
-0
lines changed
Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
# Copyright The OpenTelemetry Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
from opentelemetry.util._patch import (
18+
_get_all_subclasses,
19+
_get_leaf_subclasses,
20+
patch_leaf_subclasses,
21+
)
22+
23+
24+
class TestPatchFunctionality(unittest.TestCase):
25+
"""Test cases for the patching functionality in _patch.py"""
26+
27+
def setUp(self):
28+
"""Set up test classes for each test case"""
29+
# Create a fresh set of test classes for each test
30+
# to avoid interference between tests
31+
32+
class BaseClass:
33+
def target_method(self):
34+
return "base"
35+
36+
def other_method(self):
37+
return "base_other"
38+
39+
class IntermediateClass(BaseClass):
40+
def target_method(self):
41+
return "intermediate"
42+
43+
class LeafClass1(IntermediateClass):
44+
def target_method(self):
45+
return "leaf1"
46+
47+
class LeafClass2(IntermediateClass):
48+
def target_method(self):
49+
return "leaf2"
50+
51+
class LeafClass3(BaseClass):
52+
def target_method(self):
53+
return "leaf3"
54+
55+
class LeafClassWithoutMethod(BaseClass):
56+
pass # Inherits target_method from BaseClass
57+
58+
class LeafClassWithNonCallable(BaseClass):
59+
target_method = "not_callable"
60+
61+
self.BaseClass = BaseClass
62+
self.IntermediateClass = IntermediateClass
63+
self.LeafClass1 = LeafClass1
64+
self.LeafClass2 = LeafClass2
65+
self.LeafClass3 = LeafClass3
66+
self.LeafClassWithoutMethod = LeafClassWithoutMethod
67+
self.LeafClassWithNonCallable = LeafClassWithNonCallable
68+
69+
def test_get_all_subclasses_single_level(self):
70+
"""Test _get_all_subclasses with single inheritance level"""
71+
subclasses = _get_all_subclasses(self.BaseClass)
72+
expected_subclasses = {
73+
self.IntermediateClass,
74+
self.LeafClass1,
75+
self.LeafClass2,
76+
self.LeafClass3,
77+
self.LeafClassWithoutMethod,
78+
self.LeafClassWithNonCallable,
79+
}
80+
self.assertEqual(subclasses, expected_subclasses)
81+
82+
def test_get_all_subclasses_intermediate_class(self):
83+
"""Test _get_all_subclasses with intermediate class"""
84+
subclasses = _get_all_subclasses(self.IntermediateClass)
85+
expected_subclasses = {self.LeafClass1, self.LeafClass2}
86+
self.assertEqual(subclasses, expected_subclasses)
87+
88+
def test_get_all_subclasses_leaf_class(self):
89+
"""Test _get_all_subclasses with leaf class (no subclasses)"""
90+
subclasses = _get_all_subclasses(self.LeafClass1)
91+
self.assertEqual(subclasses, set())
92+
93+
def test_get_leaf_subclasses(self):
94+
"""Test _get_leaf_subclasses correctly identifies leaf classes"""
95+
all_subclasses = _get_all_subclasses(self.BaseClass)
96+
leaf_subclasses = _get_leaf_subclasses(all_subclasses)
97+
98+
expected_leaf_classes = {
99+
self.LeafClass1,
100+
self.LeafClass2,
101+
self.LeafClass3,
102+
self.LeafClassWithoutMethod,
103+
self.LeafClassWithNonCallable,
104+
}
105+
self.assertEqual(leaf_subclasses, expected_leaf_classes)
106+
107+
def test_get_leaf_subclasses_empty_set(self):
108+
"""Test _get_leaf_subclasses with empty set"""
109+
leaf_subclasses = _get_leaf_subclasses(set())
110+
self.assertEqual(leaf_subclasses, set())
111+
112+
def test_get_leaf_subclasses_single_class(self):
113+
"""Test _get_leaf_subclasses with single class"""
114+
leaf_subclasses = _get_leaf_subclasses({self.LeafClass1})
115+
self.assertEqual(leaf_subclasses, {self.LeafClass1})
116+
117+
def test_patch_leaf_subclasses_basic(self):
118+
"""Test basic patching functionality"""
119+
call_tracker = []
120+
121+
def wrapper(original_method):
122+
def wrapped(*args, **kwargs):
123+
call_tracker.append(f"wrapped_{original_method.__name__}")
124+
res = original_method(*args, **kwargs)
125+
return f"wrapped_{res}"
126+
127+
return wrapped
128+
129+
# Apply patch
130+
patch_leaf_subclasses(self.BaseClass, "target_method", wrapper)
131+
132+
# Test that leaf classes are patched
133+
leaf1_instance = self.LeafClass1()
134+
leaf2_instance = self.LeafClass2()
135+
leaf3_instance = self.LeafClass3()
136+
leaf_without_method_instance = self.LeafClassWithoutMethod()
137+
138+
# Check results
139+
self.assertEqual(leaf1_instance.target_method(), "wrapped_leaf1")
140+
self.assertEqual(leaf2_instance.target_method(), "wrapped_leaf2")
141+
self.assertEqual(leaf3_instance.target_method(), "wrapped_leaf3")
142+
self.assertEqual(
143+
leaf_without_method_instance.target_method(), "wrapped_base"
144+
)
145+
146+
# Check that wrapper was called
147+
expected_calls = ["wrapped_target_method"] * 4
148+
self.assertEqual(call_tracker, expected_calls)
149+
150+
# Test that intermediate class is NOT patched
151+
intermediate_instance = self.IntermediateClass()
152+
call_tracker.clear()
153+
result = intermediate_instance.target_method()
154+
self.assertEqual(result, "intermediate")
155+
self.assertEqual(call_tracker, []) # No wrapper calls
156+
157+
def test_patch_leaf_subclasses_non_callable_attribute(self):
158+
"""Test that non-callable attributes are not patched"""
159+
160+
def wrapper(original_method):
161+
def wrapped(*args, **kwargs):
162+
return "wrapped"
163+
164+
return wrapped
165+
166+
# Apply patch
167+
patch_leaf_subclasses(self.BaseClass, "target_method", wrapper)
168+
169+
# Test that class with non-callable attribute is not patched
170+
leaf_non_callable_instance = self.LeafClassWithNonCallable()
171+
self.assertEqual(
172+
leaf_non_callable_instance.target_method, "not_callable"
173+
)
174+
175+
def test_patch_leaf_subclasses_nonexistent_method(self):
176+
"""Test patching a method that doesn't exist"""
177+
178+
def wrapper(original_method):
179+
def wrapped(*args, **kwargs):
180+
return "wrapped"
181+
182+
return wrapped
183+
184+
# This should not raise an exception
185+
patch_leaf_subclasses(self.BaseClass, "nonexistent_method", wrapper)
186+
187+
# Verify that instances still work normally
188+
leaf1_instance = self.LeafClass1()
189+
self.assertEqual(leaf1_instance.target_method(), "leaf1")
190+
191+
def test_patch_leaf_subclasses_preserves_original_behavior(self):
192+
"""Test that patching preserves the original method behavior"""
193+
194+
def identity_wrapper(original_method):
195+
def wrapped(*args, **kwargs):
196+
return original_method(*args, **kwargs)
197+
198+
return wrapped
199+
200+
# Apply patch
201+
patch_leaf_subclasses(
202+
self.BaseClass, "target_method", identity_wrapper
203+
)
204+
205+
# Test that behavior is preserved
206+
leaf1_instance = self.LeafClass1()
207+
leaf2_instance = self.LeafClass2()
208+
209+
self.assertEqual(leaf1_instance.target_method(), "leaf1")
210+
self.assertEqual(leaf2_instance.target_method(), "leaf2")
211+
212+
def test_patch_leaf_subclasses_with_arguments(self):
213+
"""Test patching methods that take arguments"""
214+
215+
class TestClassWithArgs:
216+
def method_with_args(self, x, y=10):
217+
return x + y
218+
219+
class ChildClass(TestClassWithArgs):
220+
def method_with_args(self, x, y=20):
221+
return x * y
222+
223+
def arg_tracking_wrapper(original_method):
224+
def wrapped(*args, **kwargs):
225+
# Track that we're in the wrapper and call original
226+
result = original_method(*args, **kwargs)
227+
return result + 1000 # Add marker to show wrapping occurred
228+
229+
return wrapped
230+
231+
patch_leaf_subclasses(
232+
TestClassWithArgs, "method_with_args", arg_tracking_wrapper
233+
)
234+
235+
child_instance = ChildClass()
236+
result = child_instance.method_with_args(5, y=3)
237+
self.assertEqual(result, 1015) # (5 * 3) + 1000
238+
239+
def test_patch_leaf_subclasses_multiple_methods(self):
240+
"""Test patching multiple different methods"""
241+
call_tracker = []
242+
243+
def wrapper(original_method):
244+
def wrapped(*args, **kwargs):
245+
call_tracker.append(f"wrapped_{original_method.__name__}")
246+
return original_method(*args, **kwargs)
247+
248+
return wrapped
249+
250+
# Patch different methods
251+
patch_leaf_subclasses(self.BaseClass, "target_method", wrapper)
252+
patch_leaf_subclasses(self.BaseClass, "other_method", wrapper)
253+
254+
# Test both methods are patched
255+
leaf1_instance = self.LeafClass1()
256+
leaf1_instance.target_method()
257+
leaf1_instance.other_method()
258+
259+
self.assertIn("wrapped_target_method", call_tracker)
260+
self.assertIn("wrapped_other_method", call_tracker)
261+
262+
def test_complex_inheritance_hierarchy(self):
263+
"""Test with a more complex inheritance hierarchy"""
264+
265+
class A:
266+
def method(self):
267+
return "A"
268+
269+
class B(A):
270+
def method(self):
271+
return "B"
272+
273+
class C(A):
274+
def method(self):
275+
return "C"
276+
277+
class D(B):
278+
def method(self):
279+
return "D"
280+
281+
class E(B):
282+
pass # Inherits from B
283+
284+
class F(C):
285+
def method(self):
286+
return "F"
287+
288+
def wrapper(original_method):
289+
def wrapped(*args, **kwargs):
290+
result = original_method(*args, **kwargs)
291+
return f"wrapped_{result}"
292+
293+
return wrapped
294+
295+
patch_leaf_subclasses(A, "method", wrapper)
296+
297+
# Test leaf classes are patched
298+
d_instance = D()
299+
e_instance = E()
300+
f_instance = F()
301+
302+
self.assertEqual(d_instance.method(), "wrapped_D")
303+
self.assertEqual(e_instance.method(), "wrapped_B") # Inherits from B
304+
self.assertEqual(f_instance.method(), "wrapped_F")
305+
306+
# Test intermediate classes are not patched
307+
b_instance = B()
308+
c_instance = C()
309+
310+
self.assertEqual(b_instance.method(), "B")
311+
self.assertEqual(c_instance.method(), "C")
312+
313+
314+
if __name__ == "__main__":
315+
unittest.main()

0 commit comments

Comments
 (0)