Skip to content

Commit 0bb12a9

Browse files
committed
model controller setup still in progress
1 parent 8358737 commit 0bb12a9

File tree

4 files changed

+179
-89
lines changed

4 files changed

+179
-89
lines changed

ninja_extra/controllers/base.py

Lines changed: 120 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import re
33
import traceback
44
import uuid
5-
from abc import ABC
5+
from abc import ABC, abstractmethod
66
from typing import (
77
TYPE_CHECKING,
88
Any,
@@ -32,6 +32,7 @@
3232
from ninja.signature import is_async
3333
from ninja.utils import normalize_path
3434
from pydantic import BaseModel as PydanticModel
35+
from pydantic import Field, validator
3536

3637
from ninja_extra.constants import ROUTE_FUNCTION, THROTTLED_FUNCTION
3738
from ninja_extra.exceptions import APIException, NotFound, PermissionDenied, bad_request
@@ -225,31 +226,53 @@ def create_response(
225226
)
226227

227228

228-
class ModelControllerBase(ControllerBase):
229-
model: Model = None
230-
allowed_routes: List[
231-
"str"
232-
] = None # default = ['create', 'read', 'update', 'patch', 'delete', 'list']
229+
class ModelServiceBase(ABC):
230+
@abstractmethod
231+
def get_one(self, pk: Any) -> Any:
232+
pass
233+
234+
@abstractmethod
235+
def get_all(self) -> QuerySet:
236+
pass
237+
238+
@abstractmethod
239+
def create(self, schema: PydanticModel, **kwargs: Any) -> Any:
240+
pass
241+
242+
@abstractmethod
243+
def update(self, instance: Model, schema: PydanticModel, **kwargs: Any) -> Any:
244+
pass
245+
246+
@abstractmethod
247+
def patch(self, instance: Model, schema: PydanticModel, **kwargs: Any) -> Any:
248+
pass
249+
250+
@abstractmethod
251+
def delete(self, instance: Model) -> Any:
252+
pass
233253

234-
model_schema: Type[PydanticModel] = None
235-
create_schema: Type[PydanticModel] = None
236-
update_schema: Type[PydanticModel] = None
237254

238-
pagination_class: Type[PaginationBase] = PageNumberPaginationExtra
239-
pagination_response_schema: Type[PydanticModel] = PaginatedResponseSchema
240-
paginate_by: int = None
255+
class ModelService(ModelServiceBase):
256+
def __init__(self, model: Type[Model]) -> None:
257+
self.model = model
241258

242-
def get_queryset(self) -> QuerySet:
259+
def get_one(self, pk: Any) -> Any:
260+
obj = get_object_or_exception(
261+
klass=self.model, error_message=None, exception=NotFound, pk=pk
262+
)
263+
return obj
264+
265+
def get_all(self) -> QuerySet:
243266
return self.model.objects.all()
244267

245-
def perform_create(self, schema: PydanticModel, **kwargs: Any) -> Any:
268+
def create(self, schema: PydanticModel, **kwargs: Any) -> Any:
246269
data = schema.dict(by_alias=True)
247270
data.update(kwargs)
248271

249272
try:
250273
instance = self.model._default_manager.create(**data)
251274
return instance
252-
except TypeError:
275+
except TypeError as tex:
253276
tb = traceback.format_exc()
254277
msg = (
255278
"Got a `TypeError` when calling `%s.%s.create()`. "
@@ -267,27 +290,93 @@ def perform_create(self, schema: PydanticModel, **kwargs: Any) -> Any:
267290
tb,
268291
)
269292
)
270-
raise TypeError(msg)
293+
raise TypeError(msg) from tex
271294

272-
def perform_update(
273-
self, instance: Model, schema: PydanticModel, **kwargs: Any
274-
) -> Any:
295+
def update(self, instance: Model, schema: PydanticModel, **kwargs: Any) -> Any:
275296
data = schema.dict(exclude_none=True)
276297
data.update(kwargs)
277298
for attr, value in data.items():
278299
setattr(instance, attr, value)
279300
instance.save()
280301
return instance
281302

282-
def perform_patch(
283-
self, instance: Model, schema: PydanticModel, **kwargs: Any
284-
) -> Any:
285-
return self.perform_update(instance=instance, schema=schema, **kwargs)
303+
def patch(self, instance: Model, schema: PydanticModel, **kwargs: Any) -> Any:
304+
return self.update(instance=instance, schema=schema, **kwargs)
286305

287-
def perform_delete(self, instance: Model) -> Any:
306+
def delete(self, instance: Model) -> Any:
288307
instance.delete()
289308

290309

310+
class ModelConfigSchema(Tuple):
311+
in_schema: Type[PydanticModel]
312+
out_schema: Optional[Type[PydanticModel]]
313+
314+
def get_out_schema(self) -> Type[PydanticModel]:
315+
if not self.out_schema:
316+
return self.in_schema
317+
return self.out_schema
318+
319+
320+
class ModelPagination(PydanticModel):
321+
klass: Type[PaginationBase] = PageNumberPaginationExtra
322+
paginate_by: Optional[int] = None
323+
schema: Type[PydanticModel] = PaginatedResponseSchema
324+
325+
@validator("klass")
326+
def validate_klass(cls, value: Any) -> Any:
327+
if not issubclass(PaginationBase, value):
328+
raise ValueError(f"{value} is not of type `PaginationBase`")
329+
return value
330+
331+
@validator(
332+
"schema",
333+
)
334+
def validate_schema(cls, value: Any) -> Any:
335+
if not issubclass(PydanticModel, value):
336+
raise ValueError(
337+
f"{value} is not a valid type. Please use a generic pydantic model."
338+
)
339+
return value
340+
341+
342+
class ModelConfig(PydanticModel):
343+
allowed_routes: List[str] = Field(
344+
[
345+
"create",
346+
"read",
347+
"update",
348+
"patch",
349+
"delete",
350+
"list",
351+
]
352+
)
353+
create_schema: ModelConfigSchema
354+
update_schema: ModelConfigSchema
355+
patch_schema: Optional[ModelConfigSchema] = None
356+
retrieve_schema: Type[PydanticModel]
357+
pagination: ModelPagination = Field(default=ModelPagination())
358+
model: Type[Model]
359+
360+
@validator("allowed_routes")
361+
def validate_allow_routes(cls, value: List[Any]) -> Any:
362+
defaults = ["create", "read", "update", "patch", "delete", "list"]
363+
for item in value:
364+
if item not in defaults:
365+
raise ValueError(f"{item} action is not recognized in {defaults}")
366+
return value
367+
368+
@validator("model")
369+
def validate_model(cls, value: Any) -> Any:
370+
if value and hasattr(value, "objects"):
371+
return value
372+
raise ValueError(f"{value} is not a valid Django model.")
373+
374+
375+
class ModelControllerBase(ControllerBase):
376+
service: Optional[ModelService] = None
377+
model_config: Optional[ModelConfig] = None
378+
379+
291380
class APIController:
292381
_PATH_PARAMETER_COMPONENT_RE = r"{(?:(?P<converter>[^>:]+):)?(?P<parameter>[^>]+)}"
293382

@@ -407,8 +496,13 @@ def __call__(self, cls: Type) -> Union[Type, Type["ControllerBase"]]:
407496
self._controller_class = cls
408497

409498
if issubclass(cls, ModelControllerBase):
410-
builder = ModelControllerBuilder(cls, self)
411-
builder.register_model_routes()
499+
if cls.model_config:
500+
# if model_config is not provided, treat controller class as normal
501+
builder = ModelControllerBuilder(cls.model_config, self)
502+
builder.register_model_routes()
503+
# We create a global service for handle CRUD Operations at class level
504+
# giving room for it to be changed at instance level through Dependency injection
505+
cls.service = ModelService(cls.model_config.model)
412506

413507
bases = inspect.getmro(cls)
414508
for base_cls in reversed(bases):

0 commit comments

Comments
 (0)