1- def patch_abc ( abstract_base_class , method_name , w ):
1+ def patch_leaf_subclasses ( base_class , method_name , wrapper ):
22 """
3- Patches a method on leaf subclasses of an abstract base class.
3+ Patches a method on leaf subclasses of a base class.
4+
5+ Args:
6+ base_class: The base class whose leaf subclasses will be patched
7+ method_name: Name of the method to patch
8+ wrapper: Function that wraps the original method
49
510 """
6- all_subclasses = recursively_get_all_subclasses ( abstract_base_class )
7- leaf_subclasses = get_leaf_subclasses (all_subclasses )
11+ all_subclasses = _get_all_subclasses ( base_class )
12+ leaf_subclasses = _get_leaf_subclasses (all_subclasses )
813
914 for subclass in leaf_subclasses :
1015 # Patch if the subclass has the method (either defined or inherited)
1116 # and it's actually callable
1217 if hasattr (subclass , method_name ) and callable (getattr (subclass , method_name )):
1318 old_method = getattr (subclass , method_name )
14- setattr (subclass , method_name , w (old_method ))
19+ setattr (subclass , method_name , wrapper (old_method ))
1520
1621 # This implementation does not work if the instrumented class is imported after the instrumentor runs.
1722 # However, that case can be handled by querying the gc module for all existing classes; this capability can be added
1823 # in a follow-up change.
1924
2025
21- def get_leaf_subclasses (all_subclasses ):
26+ def _get_leaf_subclasses (all_subclasses ):
2227 """
2328 Returns only the leaf classes (classes with no subclasses) from a set of classes.
29+
30+ Args:
31+ all_subclasses: Set of classes to filter
32+
33+ Returns:
34+ set: Classes that have no subclasses within the provided set
2435 """
2536 leaf_classes = set ()
2637 for cls in all_subclasses :
@@ -35,10 +46,18 @@ def get_leaf_subclasses(all_subclasses):
3546 return leaf_classes
3647
3748
49+ def _get_all_subclasses (cls ):
50+ """
51+ Gets all subclasses of a given class.
52+
53+ Args:
54+ cls: The base class to find subclasses for
3855
39- def recursively_get_all_subclasses (cls ):
40- out = set ()
56+ Returns:
57+ set: All subclasses (direct and indirect) of the given class
58+ """
59+ subclasses = set ()
4160 for subclass in cls .__subclasses__ ():
42- out .add (subclass )
43- out .update (recursively_get_all_subclasses (subclass ))
44- return out
61+ subclasses .add (subclass )
62+ subclasses .update (_get_all_subclasses (subclass ))
63+ return subclasses
0 commit comments