|
| 1 | +from drf_spectacular.extensions import OpenApiFilterExtension |
| 2 | +from drf_spectacular.plumbing import build_parameter_type |
| 3 | +from drf_spectacular.utils import OpenApiParameter |
| 4 | + |
| 5 | +from ansible_base.rest_filters.rest_framework.field_lookup_backend import FieldLookupBackend |
| 6 | +from ansible_base.rest_filters.rest_framework.order_backend import OrderByBackend |
| 7 | +from ansible_base.rest_filters.rest_framework.type_filter_backend import TypeFilterBackend |
| 8 | + |
| 9 | + |
| 10 | +class FieldLookupBackendExtension(OpenApiFilterExtension): |
| 11 | + target_class = FieldLookupBackend |
| 12 | + |
| 13 | + def get_schema_operation_parameters(self, auto_schema, *args, **kwargs): |
| 14 | + """ |
| 15 | + Generate OpenAPI parameters for FieldLookupBackend. |
| 16 | +
|
| 17 | + This filter backend supports field lookups on any model field using Django's |
| 18 | + field lookup syntax (e.g., field__exact, field__contains, field__gt, etc.). |
| 19 | + Since the actual fields depend on the model, we provide generic examples. |
| 20 | + """ |
| 21 | + parameters = [] |
| 22 | + |
| 23 | + # Add model-based parameters if model is available |
| 24 | + if self._has_model_queryset(auto_schema): |
| 25 | + model = auto_schema.view.queryset.model |
| 26 | + model_fields = self._get_relevant_model_fields(model) |
| 27 | + parameters.extend(self._create_model_field_parameters(model, model_fields)) |
| 28 | + |
| 29 | + # Add RBAC parameter |
| 30 | + parameters.append(self._create_role_level_parameter()) |
| 31 | + |
| 32 | + return parameters |
| 33 | + |
| 34 | + def _has_model_queryset(self, auto_schema): |
| 35 | + """Check if the view has a model queryset.""" |
| 36 | + return hasattr(auto_schema.view, 'queryset') and auto_schema.view.queryset is not None |
| 37 | + |
| 38 | + def _get_relevant_model_fields(self, model): |
| 39 | + """Get relevant model fields, excluding complex relationships.""" |
| 40 | + model_fields = [] |
| 41 | + for field in model._meta.get_fields(): |
| 42 | + if self._is_simple_field(field): |
| 43 | + model_fields.append(field.name) |
| 44 | + return model_fields |
| 45 | + |
| 46 | + def _is_simple_field(self, field): |
| 47 | + """Check if field is a simple field (not many-to-many or one-to-many).""" |
| 48 | + return hasattr(field, 'name') and not field.many_to_many and not (hasattr(field, 'one_to_many') and field.one_to_many) |
| 49 | + |
| 50 | + def _create_model_field_parameters(self, model, field_names): |
| 51 | + """Create parameters for all model fields.""" |
| 52 | + parameters = [] |
| 53 | + for field_name in field_names: |
| 54 | + parameters.extend(self._create_field_parameters(model, field_name)) |
| 55 | + return parameters |
| 56 | + |
| 57 | + def _create_field_parameters(self, model, field_name): |
| 58 | + """Create all parameter variations for a single field.""" |
| 59 | + parameters = [] |
| 60 | + |
| 61 | + # Basic exact match parameter |
| 62 | + parameters.append(self._create_parameter(field_name, f'Filter by {field_name} (exact match)')) |
| 63 | + |
| 64 | + # Add field-type specific parameters |
| 65 | + field_obj = self._get_field_by_name(model, field_name) |
| 66 | + if field_obj: |
| 67 | + if self._is_string_field(field_obj): |
| 68 | + parameters.append(self._create_parameter(f'{field_name}__icontains', f'Filter by {field_name} (case-insensitive partial match)')) |
| 69 | + |
| 70 | + if self._is_numeric_or_date_field(field_obj): |
| 71 | + parameters.extend(self._create_comparison_parameters(field_name)) |
| 72 | + |
| 73 | + return parameters |
| 74 | + |
| 75 | + def _get_field_by_name(self, model, field_name): |
| 76 | + """Get field object by name from model.""" |
| 77 | + for field in model._meta.get_fields(): |
| 78 | + if hasattr(field, 'name') and field.name == field_name: |
| 79 | + return field |
| 80 | + return None |
| 81 | + |
| 82 | + def _is_string_field(self, field): |
| 83 | + """Check if field is a string-based field.""" |
| 84 | + from django.db import models |
| 85 | + |
| 86 | + return isinstance(field, (models.CharField, models.TextField)) |
| 87 | + |
| 88 | + def _is_numeric_or_date_field(self, field): |
| 89 | + """Check if field is numeric or date-based.""" |
| 90 | + from django.db import models |
| 91 | + |
| 92 | + numeric_date_types = (models.IntegerField, models.DateTimeField, models.DateField, models.DecimalField, models.FloatField) |
| 93 | + return isinstance(field, numeric_date_types) |
| 94 | + |
| 95 | + def _create_comparison_parameters(self, field_name): |
| 96 | + """Create comparison parameters (gt, gte, lt, lte) for a field.""" |
| 97 | + parameters = [] |
| 98 | + for lookup in ['gt', 'gte', 'lt', 'lte']: |
| 99 | + parameters.append(self._create_parameter(f'{field_name}__{lookup}', f'Filter by {field_name} ({lookup})')) |
| 100 | + return parameters |
| 101 | + |
| 102 | + def _create_parameter(self, name, description): |
| 103 | + """Create a single OpenAPI parameter.""" |
| 104 | + return build_parameter_type( |
| 105 | + name=name, |
| 106 | + schema={'type': 'string'}, |
| 107 | + location=OpenApiParameter.QUERY, |
| 108 | + required=False, |
| 109 | + description=description, |
| 110 | + ) |
| 111 | + |
| 112 | + def _create_role_level_parameter(self): |
| 113 | + """Create the role_level parameter for RBAC.""" |
| 114 | + return self._create_parameter('role_level', 'Filter by role level for RBAC') |
| 115 | + |
| 116 | + |
| 117 | +class TypeFilterBackendExtension(OpenApiFilterExtension): |
| 118 | + target_class = TypeFilterBackend |
| 119 | + |
| 120 | + def get_schema_operation_parameters(self, auto_schema, *args, **kwargs): |
| 121 | + """ |
| 122 | + Generate OpenAPI parameters for TypeFilterBackend. |
| 123 | +
|
| 124 | + This filter backend supports filtering by object type. |
| 125 | + """ |
| 126 | + return [ |
| 127 | + build_parameter_type( |
| 128 | + name='type', |
| 129 | + schema={'type': 'string'}, |
| 130 | + location=OpenApiParameter.QUERY, |
| 131 | + required=False, |
| 132 | + description='Filter by object type. Supports comma-separated values for multiple types.', |
| 133 | + ) |
| 134 | + ] |
| 135 | + |
| 136 | + |
| 137 | +class OrderByBackendExtension(OpenApiFilterExtension): |
| 138 | + target_class = OrderByBackend |
| 139 | + |
| 140 | + def get_schema_operation_parameters(self, auto_schema, *args, **kwargs): |
| 141 | + """ |
| 142 | + Generate OpenAPI parameters for OrderByBackend. |
| 143 | +
|
| 144 | + This filter backend supports ordering results by field names. |
| 145 | + """ |
| 146 | + parameters = [] |
| 147 | + |
| 148 | + # Add the ordering parameters |
| 149 | + for param_name in ['order', 'order_by']: |
| 150 | + parameters.append( |
| 151 | + build_parameter_type( |
| 152 | + name=param_name, |
| 153 | + schema={'type': 'string'}, |
| 154 | + location=OpenApiParameter.QUERY, |
| 155 | + required=False, |
| 156 | + description='Order results by field name. Prefix with \'-\' for descending order. Supports comma-separated values for multiple fields.', |
| 157 | + ) |
| 158 | + ) |
| 159 | + |
| 160 | + return parameters |
0 commit comments