|
1 | 1 | import inspect |
2 | 2 | import re |
3 | | -import traceback |
4 | 3 | import uuid |
5 | | -from abc import ABC, abstractmethod |
| 4 | +from abc import ABC |
6 | 5 | from typing import ( |
7 | 6 | TYPE_CHECKING, |
8 | 7 | Any, |
|
27 | 26 | from injector import inject, is_decorated_with_inject |
28 | 27 | from ninja import NinjaAPI, Router |
29 | 28 | from ninja.constants import NOT_SET |
30 | | -from ninja.pagination import PaginationBase |
31 | 29 | from ninja.security.base import AuthBase |
32 | 30 | from ninja.signature import is_async |
33 | 31 | from ninja.utils import normalize_path |
34 | | -from pydantic import BaseModel as PydanticModel |
35 | | -from pydantic import Field, validator |
36 | 32 |
|
37 | 33 | from ninja_extra.constants import ROUTE_FUNCTION, THROTTLED_FUNCTION |
38 | 34 | from ninja_extra.exceptions import APIException, NotFound, PermissionDenied, bad_request |
|
47 | 43 | ) |
48 | 44 | from ninja_extra.types import PermissionType |
49 | 45 |
|
50 | | -from ..pagination import PageNumberPaginationExtra, PaginatedResponseSchema |
51 | | -from .model_controller_builder import ModelControllerBuilder |
| 46 | +from .model import ModelConfig, ModelControllerBuilder, ModelService |
52 | 47 | from .registry import ControllerRegistry |
53 | 48 | from .response import Detail, Id, Ok |
54 | 49 | from .route.route_functions import AsyncRouteFunction, RouteFunction |
@@ -226,154 +221,8 @@ def create_response( |
226 | 221 | ) |
227 | 222 |
|
228 | 223 |
|
229 | | -class ModelServiceBase(ABC): |
230 | | - @abstractmethod |
231 | | - def get_one(self, pk: Any) -> Any: |
232 | | - pass |
233 | | - |
234 | | - @abstractmethod |
235 | | - def get_all(self) -> QuerySet: |
236 | | - pass |
237 | | - |
238 | | - @abstractmethod |
239 | | - def create(self, schema: PydanticModel, **kwargs: Any) -> Any: |
240 | | - pass |
241 | | - |
242 | | - @abstractmethod |
243 | | - def update(self, instance: Model, schema: PydanticModel, **kwargs: Any) -> Any: |
244 | | - pass |
245 | | - |
246 | | - @abstractmethod |
247 | | - def patch(self, instance: Model, schema: PydanticModel, **kwargs: Any) -> Any: |
248 | | - pass |
249 | | - |
250 | | - @abstractmethod |
251 | | - def delete(self, instance: Model) -> Any: |
252 | | - pass |
253 | | - |
254 | | - |
255 | | -class ModelService(ModelServiceBase): |
256 | | - def __init__(self, model: Type[Model]) -> None: |
257 | | - self.model = model |
258 | | - |
259 | | - def get_one(self, pk: Any) -> Any: |
260 | | - obj = get_object_or_exception( |
261 | | - klass=self.model, error_message=None, exception=NotFound, pk=pk |
262 | | - ) |
263 | | - return obj |
264 | | - |
265 | | - def get_all(self) -> QuerySet: |
266 | | - return self.model.objects.all() |
267 | | - |
268 | | - def create(self, schema: PydanticModel, **kwargs: Any) -> Any: |
269 | | - data = schema.dict(by_alias=True) |
270 | | - data.update(kwargs) |
271 | | - |
272 | | - try: |
273 | | - instance = self.model._default_manager.create(**data) |
274 | | - return instance |
275 | | - except TypeError as tex: |
276 | | - tb = traceback.format_exc() |
277 | | - msg = ( |
278 | | - "Got a `TypeError` when calling `%s.%s.create()`. " |
279 | | - "This may be because you have a writable field on the " |
280 | | - "serializer class that is not a valid argument to " |
281 | | - "`%s.%s.create()`. You may need to make the field " |
282 | | - "read-only, or override the %s.create() method to handle " |
283 | | - "this correctly.\nOriginal exception was:\n %s" |
284 | | - % ( |
285 | | - self.model.__name__, |
286 | | - self.model._default_manager.name, |
287 | | - self.model.__name__, |
288 | | - self.model._default_manager.name, |
289 | | - self.__class__.__name__, |
290 | | - tb, |
291 | | - ) |
292 | | - ) |
293 | | - raise TypeError(msg) from tex |
294 | | - |
295 | | - def update(self, instance: Model, schema: PydanticModel, **kwargs: Any) -> Any: |
296 | | - data = schema.dict(exclude_none=True) |
297 | | - data.update(kwargs) |
298 | | - for attr, value in data.items(): |
299 | | - setattr(instance, attr, value) |
300 | | - instance.save() |
301 | | - return instance |
302 | | - |
303 | | - def patch(self, instance: Model, schema: PydanticModel, **kwargs: Any) -> Any: |
304 | | - return self.update(instance=instance, schema=schema, **kwargs) |
305 | | - |
306 | | - def delete(self, instance: Model) -> Any: |
307 | | - instance.delete() |
308 | | - |
309 | | - |
310 | | -class ModelConfigSchema(Tuple): |
311 | | - in_schema: Type[PydanticModel] |
312 | | - out_schema: Optional[Type[PydanticModel]] |
313 | | - |
314 | | - def get_out_schema(self) -> Type[PydanticModel]: |
315 | | - if not self.out_schema: |
316 | | - return self.in_schema |
317 | | - return self.out_schema |
318 | | - |
319 | | - |
320 | | -class ModelPagination(PydanticModel): |
321 | | - klass: Type[PaginationBase] = PageNumberPaginationExtra |
322 | | - paginate_by: Optional[int] = None |
323 | | - schema: Type[PydanticModel] = PaginatedResponseSchema |
324 | | - |
325 | | - @validator("klass") |
326 | | - def validate_klass(cls, value: Any) -> Any: |
327 | | - if not issubclass(PaginationBase, value): |
328 | | - raise ValueError(f"{value} is not of type `PaginationBase`") |
329 | | - return value |
330 | | - |
331 | | - @validator( |
332 | | - "schema", |
333 | | - ) |
334 | | - def validate_schema(cls, value: Any) -> Any: |
335 | | - if not issubclass(PydanticModel, value): |
336 | | - raise ValueError( |
337 | | - f"{value} is not a valid type. Please use a generic pydantic model." |
338 | | - ) |
339 | | - return value |
340 | | - |
341 | | - |
342 | | -class ModelConfig(PydanticModel): |
343 | | - allowed_routes: List[str] = Field( |
344 | | - [ |
345 | | - "create", |
346 | | - "read", |
347 | | - "update", |
348 | | - "patch", |
349 | | - "delete", |
350 | | - "list", |
351 | | - ] |
352 | | - ) |
353 | | - create_schema: ModelConfigSchema |
354 | | - update_schema: ModelConfigSchema |
355 | | - patch_schema: Optional[ModelConfigSchema] = None |
356 | | - retrieve_schema: Type[PydanticModel] |
357 | | - pagination: ModelPagination = Field(default=ModelPagination()) |
358 | | - model: Type[Model] |
359 | | - |
360 | | - @validator("allowed_routes") |
361 | | - def validate_allow_routes(cls, value: List[Any]) -> Any: |
362 | | - defaults = ["create", "read", "update", "patch", "delete", "list"] |
363 | | - for item in value: |
364 | | - if item not in defaults: |
365 | | - raise ValueError(f"{item} action is not recognized in {defaults}") |
366 | | - return value |
367 | | - |
368 | | - @validator("model") |
369 | | - def validate_model(cls, value: Any) -> Any: |
370 | | - if value and hasattr(value, "objects"): |
371 | | - return value |
372 | | - raise ValueError(f"{value} is not a valid Django model.") |
373 | | - |
374 | | - |
375 | 224 | class ModelControllerBase(ControllerBase): |
376 | | - service: Optional[ModelService] = None |
| 225 | + service: ModelService |
377 | 226 | model_config: Optional[ModelConfig] = None |
378 | 227 |
|
379 | 228 |
|
@@ -498,7 +347,7 @@ def __call__(self, cls: Type) -> Union[Type, Type["ControllerBase"]]: |
498 | 347 | if issubclass(cls, ModelControllerBase): |
499 | 348 | if cls.model_config: |
500 | 349 | # if model_config is not provided, treat controller class as normal |
501 | | - builder = ModelControllerBuilder(cls.model_config, self) |
| 350 | + builder = ModelControllerBuilder(cls, self) |
502 | 351 | builder.register_model_routes() |
503 | 352 | # We create a global service for handle CRUD Operations at class level |
504 | 353 | # giving room for it to be changed at instance level through Dependency injection |
|
0 commit comments