1- from typing import Any
1+ from typing import Any , get_args , get_origin
22
33from hamcrest import anything
44from hamcrest .core .base_matcher import BaseMatcher
@@ -14,8 +14,21 @@ def __new__(cls, name, bases, namespace, **_kwargs):
1414 return super ().__new__ (cls , name , bases , namespace )
1515
1616 domain_class = namespace .get ("__domain_class__" )
17+
1718 if domain_class is None :
18- msg = f"{ name } must define __domain_class__"
19+ orig_bases = namespace .get ("__orig_bases__" , [])
20+ for orig in orig_bases :
21+ origin = get_origin (orig )
22+ args = get_args (orig )
23+ if origin is BaseAutoMatcher and args :
24+ inferred_type = args [0 ]
25+ if hasattr (inferred_type , "__annotations__" ):
26+ domain_class = inferred_type
27+ namespace ["__domain_class__" ] = domain_class
28+ break
29+
30+ if domain_class is None or not hasattr (domain_class , "__annotations__" ):
31+ msg = f"{ name } must define or infer __domain_class__ with annotations"
1932 raise TypeError (msg )
2033
2134 for field_name in domain_class .__annotations__ :
@@ -25,8 +38,27 @@ def __new__(cls, name, bases, namespace, **_kwargs):
2538 return super ().__new__ (cls , name , bases , namespace )
2639
2740
28- class BaseAutoMatcher (BaseMatcher , metaclass = AutoMatcherMeta ):
29- __domain_class__ = None # must be overridden
41+ class BaseAutoMatcher [T ](BaseMatcher , metaclass = AutoMatcherMeta ):
42+ """Create matchers for classes. Use like so:
43+
44+ ```python
45+ from hamcrest import assert_that, equal_to
46+
47+ class EligibilityStatus(BaseModel):
48+ status: str
49+ reason: str | None = None
50+
51+ class EligibilityStatusMatcher(BaseAutoMatcher[EligibilityStatus]): ...
52+ def is_eligibility_status() -> Matcher[EligibilityStatus]: return EligibilityStatusMatcher()
53+
54+ assert_that(EligibilityStatus(status="ACTIVE"), is_eligibility_status().with_status("ACTIVE").and_reason(None))
55+ ```
56+
57+ Works only for classes with `__annotations__`; manually annotated classes, dataclasses.dataclass and
58+ pydantic.BaseModel instances.
59+ """
60+
61+ __domain_class__ = None # Will be inferred when subclassed generically
3062
3163 def __init__ (self ):
3264 super ().__init__ ()
@@ -37,24 +69,24 @@ def describe_to(self, description: Description) -> None:
3769 attr_name = f"{ field_name } _" if field_name in {"id" , "type" } else field_name
3870 self .append_matcher_description (getattr (self , attr_name ), field_name , description )
3971
40- def _matches (self , item ) -> bool :
72+ def _matches (self , item : T ) -> bool :
4173 return all (
4274 getattr (self , f"{ field } _" if field in {"id" , "type" } else field ).matches (getattr (item , field ))
4375 for field in self .__domain_class__ .__annotations__
4476 )
4577
46- def describe_mismatch (self , item , mismatch_description : Description ) -> None :
78+ def describe_mismatch (self , item : T , mismatch_description : Description ) -> None :
4779 mismatch_description .append_text (f"was { self .__domain_class__ .__name__ } with" )
4880 for field_name in self .__domain_class__ .__annotations__ :
49- value = getattr (item , field_name )
5081 matcher = getattr (self , f"{ field_name } _" if field_name in {"id" , "type" } else field_name )
82+ value = getattr (item , field_name )
5183 self .describe_field_mismatch (matcher , field_name , value , mismatch_description )
5284
53- def describe_match (self , item , match_description : Description ) -> None :
85+ def describe_match (self , item : T , match_description : Description ) -> None :
5486 match_description .append_text (f"was { self .__domain_class__ .__name__ } with" )
5587 for field_name in self .__domain_class__ .__annotations__ :
56- value = getattr (item , field_name )
5788 matcher = getattr (self , f"{ field_name } _" if field_name in {"id" , "type" } else field_name )
89+ value = getattr (item , field_name )
5890 self .describe_field_match (matcher , field_name , value , match_description )
5991
6092 def __getattr__ (self , name : str ):
@@ -74,8 +106,8 @@ def setter(value):
74106 def __dir__ (self ):
75107 dynamic_methods = []
76108 for field_name in self .__domain_class__ .__annotations__ :
77- method_base = field_name .rstrip ("_" ) if field_name in {"id" , "type" } else field_name
78- dynamic_methods .extend ([f"with_{ method_base } " , f"and_{ method_base } " ])
109+ base = field_name .rstrip ("_" ) if field_name in {"id" , "type" } else field_name
110+ dynamic_methods .extend ([f"with_{ base } " , f"and_{ base } " ])
79111 return list (super ().__dir__ ()) + dynamic_methods
80112
81113 @staticmethod
0 commit comments