22import re
33import traceback
44import uuid
5- from abc import ABC
5+ from abc import ABC , abstractmethod
66from typing import (
77 TYPE_CHECKING ,
88 Any ,
3232from ninja .signature import is_async
3333from ninja .utils import normalize_path
3434from pydantic import BaseModel as PydanticModel
35+ from pydantic import Field , validator
3536
3637from ninja_extra .constants import ROUTE_FUNCTION , THROTTLED_FUNCTION
3738from ninja_extra .exceptions import APIException , NotFound , PermissionDenied , bad_request
@@ -225,31 +226,53 @@ def create_response(
225226 )
226227
227228
228- class ModelControllerBase (ControllerBase ):
229- model : Model = None
230- allowed_routes : List [
231- "str"
232- ] = None # default = ['create', 'read', 'update', 'patch', 'delete', 'list']
229+ class ModelServiceBase (ABC ):
230+ @abstractmethod
231+ def get_one (self , pk : Any ) -> Any :
232+ pass
233+
234+ @abstractmethod
235+ def get_all (self ) -> QuerySet :
236+ pass
237+
238+ @abstractmethod
239+ def create (self , schema : PydanticModel , ** kwargs : Any ) -> Any :
240+ pass
241+
242+ @abstractmethod
243+ def update (self , instance : Model , schema : PydanticModel , ** kwargs : Any ) -> Any :
244+ pass
245+
246+ @abstractmethod
247+ def patch (self , instance : Model , schema : PydanticModel , ** kwargs : Any ) -> Any :
248+ pass
249+
250+ @abstractmethod
251+ def delete (self , instance : Model ) -> Any :
252+ pass
233253
234- model_schema : Type [PydanticModel ] = None
235- create_schema : Type [PydanticModel ] = None
236- update_schema : Type [PydanticModel ] = None
237254
238- pagination_class : Type [ PaginationBase ] = PageNumberPaginationExtra
239- pagination_response_schema : Type [PydanticModel ] = PaginatedResponseSchema
240- paginate_by : int = None
255+ class ModelService ( ModelServiceBase ):
256+ def __init__ ( self , model : Type [Model ]) -> None :
257+ self . model = model
241258
242- def get_queryset (self ) -> QuerySet :
259+ def get_one (self , pk : Any ) -> Any :
260+ obj = get_object_or_exception (
261+ klass = self .model , error_message = None , exception = NotFound , pk = pk
262+ )
263+ return obj
264+
265+ def get_all (self ) -> QuerySet :
243266 return self .model .objects .all ()
244267
245- def perform_create (self , schema : PydanticModel , ** kwargs : Any ) -> Any :
268+ def create (self , schema : PydanticModel , ** kwargs : Any ) -> Any :
246269 data = schema .dict (by_alias = True )
247270 data .update (kwargs )
248271
249272 try :
250273 instance = self .model ._default_manager .create (** data )
251274 return instance
252- except TypeError :
275+ except TypeError as tex :
253276 tb = traceback .format_exc ()
254277 msg = (
255278 "Got a `TypeError` when calling `%s.%s.create()`. "
@@ -267,27 +290,93 @@ def perform_create(self, schema: PydanticModel, **kwargs: Any) -> Any:
267290 tb ,
268291 )
269292 )
270- raise TypeError (msg )
293+ raise TypeError (msg ) from tex
271294
272- def perform_update (
273- self , instance : Model , schema : PydanticModel , ** kwargs : Any
274- ) -> Any :
295+ def update (self , instance : Model , schema : PydanticModel , ** kwargs : Any ) -> Any :
275296 data = schema .dict (exclude_none = True )
276297 data .update (kwargs )
277298 for attr , value in data .items ():
278299 setattr (instance , attr , value )
279300 instance .save ()
280301 return instance
281302
282- def perform_patch (
283- self , instance : Model , schema : PydanticModel , ** kwargs : Any
284- ) -> Any :
285- return self .perform_update (instance = instance , schema = schema , ** kwargs )
303+ def patch (self , instance : Model , schema : PydanticModel , ** kwargs : Any ) -> Any :
304+ return self .update (instance = instance , schema = schema , ** kwargs )
286305
287- def perform_delete (self , instance : Model ) -> Any :
306+ def delete (self , instance : Model ) -> Any :
288307 instance .delete ()
289308
290309
310+ class ModelConfigSchema (Tuple ):
311+ in_schema : Type [PydanticModel ]
312+ out_schema : Optional [Type [PydanticModel ]]
313+
314+ def get_out_schema (self ) -> Type [PydanticModel ]:
315+ if not self .out_schema :
316+ return self .in_schema
317+ return self .out_schema
318+
319+
320+ class ModelPagination (PydanticModel ):
321+ klass : Type [PaginationBase ] = PageNumberPaginationExtra
322+ paginate_by : Optional [int ] = None
323+ schema : Type [PydanticModel ] = PaginatedResponseSchema
324+
325+ @validator ("klass" )
326+ def validate_klass (cls , value : Any ) -> Any :
327+ if not issubclass (PaginationBase , value ):
328+ raise ValueError (f"{ value } is not of type `PaginationBase`" )
329+ return value
330+
331+ @validator (
332+ "schema" ,
333+ )
334+ def validate_schema (cls , value : Any ) -> Any :
335+ if not issubclass (PydanticModel , value ):
336+ raise ValueError (
337+ f"{ value } is not a valid type. Please use a generic pydantic model."
338+ )
339+ return value
340+
341+
342+ class ModelConfig (PydanticModel ):
343+ allowed_routes : List [str ] = Field (
344+ [
345+ "create" ,
346+ "read" ,
347+ "update" ,
348+ "patch" ,
349+ "delete" ,
350+ "list" ,
351+ ]
352+ )
353+ create_schema : ModelConfigSchema
354+ update_schema : ModelConfigSchema
355+ patch_schema : Optional [ModelConfigSchema ] = None
356+ retrieve_schema : Type [PydanticModel ]
357+ pagination : ModelPagination = Field (default = ModelPagination ())
358+ model : Type [Model ]
359+
360+ @validator ("allowed_routes" )
361+ def validate_allow_routes (cls , value : List [Any ]) -> Any :
362+ defaults = ["create" , "read" , "update" , "patch" , "delete" , "list" ]
363+ for item in value :
364+ if item not in defaults :
365+ raise ValueError (f"{ item } action is not recognized in { defaults } " )
366+ return value
367+
368+ @validator ("model" )
369+ def validate_model (cls , value : Any ) -> Any :
370+ if value and hasattr (value , "objects" ):
371+ return value
372+ raise ValueError (f"{ value } is not a valid Django model." )
373+
374+
375+ class ModelControllerBase (ControllerBase ):
376+ service : Optional [ModelService ] = None
377+ model_config : Optional [ModelConfig ] = None
378+
379+
291380class APIController :
292381 _PATH_PARAMETER_COMPONENT_RE = r"{(?:(?P<converter>[^>:]+):)?(?P<parameter>[^>]+)}"
293382
@@ -407,8 +496,13 @@ def __call__(self, cls: Type) -> Union[Type, Type["ControllerBase"]]:
407496 self ._controller_class = cls
408497
409498 if issubclass (cls , ModelControllerBase ):
410- builder = ModelControllerBuilder (cls , self )
411- builder .register_model_routes ()
499+ if cls .model_config :
500+ # if model_config is not provided, treat controller class as normal
501+ builder = ModelControllerBuilder (cls .model_config , self )
502+ builder .register_model_routes ()
503+ # We create a global service for handle CRUD Operations at class level
504+ # giving room for it to be changed at instance level through Dependency injection
505+ cls .service = ModelService (cls .model_config .model )
412506
413507 bases = inspect .getmro (cls )
414508 for base_cls in reversed (bases ):
0 commit comments