33import posixpath
44import warnings
55from abc import ABC
6+ from typing import ItemsView , cast
67
78from typing_extensions import (
89 TYPE_CHECKING ,
@@ -73,6 +74,8 @@ def __init__(self, ctx: Context, path: str, /, **attributes):
7374class Resource (Protocol ):
7475 def __getitem__ (self , key : Hashable ) -> Any : ...
7576
77+ def items (self ) -> ItemsView : ...
78+
7679
7780class _Resource (dict , Resource ):
7881 def __init__ (self , ctx : Context , path : str , ** attributes ):
@@ -92,6 +95,10 @@ def update(self, **attributes): # type: ignore[reportIncompatibleMethodOverride
9295T = TypeVar ("T" , bound = Resource )
9396
9497
98+ class ResourceFactory (Protocol ):
99+ def __call__ (self , ctx : Context , path : str , ** attributes ) -> Resource : ...
100+
101+
95102class ResourceSequence (Protocol [T ]):
96103 @overload
97104 def __getitem__ (self , index : SupportsIndex , / ) -> T : ...
@@ -109,10 +116,17 @@ def __repr__(self) -> str: ...
109116
110117
111118class _ResourceSequence (Sequence [T ], ResourceSequence [T ]):
112- def __init__ (self , ctx : Context , path : str , * , uid : str = "guid" ):
119+ def __init__ (
120+ self ,
121+ ctx : Context ,
122+ path : str ,
123+ factory : ResourceFactory = _Resource ,
124+ uid : str = "guid" ,
125+ ):
113126 self ._ctx = ctx
114127 self ._path = path
115128 self ._uid = uid
129+ self ._factory = factory
116130
117131 def __getitem__ (self , index ):
118132 return list (self .fetch ())[index ]
@@ -129,32 +143,32 @@ def __str__(self) -> str:
129143 def __repr__ (self ) -> str :
130144 return repr (self .fetch ())
131145
132- def create (self , ** attributes : Any ) -> Any :
146+ def create (self , ** attributes : Any ) -> T :
133147 response = self ._ctx .client .post (self ._path , json = attributes )
134148 result = response .json ()
135149 uid = result [self ._uid ]
136150 path = posixpath .join (self ._path , uid )
137- return _Resource ( self ._ctx , path , ** result )
151+ return cast ( T , self ._factory ( self . _ctx , path , ** result ) )
138152
139- def fetch (self , ** conditions ) -> Iterable [Any ]:
153+ def fetch (self , ** conditions ) -> Iterable [T ]:
140154 response = self ._ctx .client .get (self ._path , params = conditions )
141155 results = response .json ()
142- resources = []
156+ resources : List [ T ] = []
143157 for result in results :
144158 uid = result [self ._uid ]
145159 path = posixpath .join (self ._path , uid )
146- resource = _Resource ( self ._ctx , path , ** result )
160+ resource = cast ( T , self ._factory ( self . _ctx , path , ** result ) )
147161 resources .append (resource )
148162
149163 return resources
150164
151- def find (self , * args : str ) -> Any :
165+ def find (self , * args : str ) -> T :
152166 path = posixpath .join (self ._path , * args )
153167 response = self ._ctx .client .get (path )
154168 result = response .json ()
155- return _Resource ( self ._ctx , path , ** result )
169+ return cast ( T , self ._factory ( self . _ctx , path , ** result ) )
156170
157- def find_by (self , ** conditions ) -> Any | None :
171+ def find_by (self , ** conditions ) -> T | None :
158172 """
159173 Find the first record matching the specified conditions.
160174
@@ -169,19 +183,19 @@ def find_by(self, **conditions) -> Any | None:
169183 Optional[T]
170184 The first record matching the conditions, or `None` if no match is found.
171185 """
172- collection = self .fetch (** conditions )
186+ collection : Iterable [ T ] = self .fetch (** conditions )
173187 return next ((v for v in collection if v .items () >= conditions .items ()), None )
174188
175189
176- class _PaginatedResourceSequence (_ResourceSequence ):
177- def fetch (self , ** conditions ):
190+ class _PaginatedResourceSequence (_ResourceSequence [ T ] ):
191+ def fetch (self , ** conditions ) -> Iterator [ T ] :
178192 paginator = Paginator (self ._ctx , self ._path , dict (** conditions ))
179193 for page in paginator .fetch_pages ():
180194 resources = []
181195 results = page .results
182196 for result in results :
183197 uid = result [self ._uid ]
184198 path = posixpath .join (self ._path , uid )
185- resource = _Resource ( self ._ctx , path , ** result )
199+ resource = cast ( T , self ._factory ( self . _ctx , path , ** result ) )
186200 resources .append (resource )
187201 yield from resources
0 commit comments