Skip to content

Commit 4709dc8

Browse files
[API] Search improvements (#11094)
* Improve prefetching * Cache user groups for permission check * Use a GET request to execute search - Prevent forced prefetch - Reduce execution time significantly * Fix group caching * Improve StockItemSerializer - Select related for pricing_data rather than prefetch * Add benchmarking for search endpoint * Adjust prefetch * Ensure no errors returned * Fix prefetch * Fix more prefetch issues * Remove debug print * Fix for performance testing * Data is already returned as dict * Test fix * Extract model types better
1 parent 2457197 commit 4709dc8

File tree

9 files changed

+125
-32
lines changed

9 files changed

+125
-32
lines changed

src/backend/InvenTree/InvenTree/api.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
1818
from rest_framework import serializers
1919
from rest_framework.generics import GenericAPIView
20+
from rest_framework.request import clone_request
2021
from rest_framework.response import Response
2122
from rest_framework.serializers import ValidationError
2223
from rest_framework.views import APIView
@@ -31,7 +32,7 @@
3132
from InvenTree.sso import sso_registration_enabled
3233
from plugin.serializers import MetadataSerializer
3334
from users.models import ApiToken
34-
from users.permissions import check_user_permission
35+
from users.permissions import check_user_permission, prefetch_rule_sets
3536

3637
from .helpers import plugins_info
3738
from .helpers_email import is_email_configured
@@ -767,6 +768,13 @@ def post(self, request, *args, **kwargs):
767768

768769
search_filters = self.get_result_filters()
769770

771+
# Create a clone of the request object to modify
772+
# Use GET method for the individual list views
773+
cloned_request = clone_request(request, 'GET')
774+
775+
# Fetch and cache all groups associated with the current user
776+
groups = prefetch_rule_sets(request.user)
777+
770778
for key, cls in self.get_result_types().items():
771779
# Only return results which are specifically requested
772780
if key in data:
@@ -790,22 +798,23 @@ def post(self, request, *args, **kwargs):
790798
view = cls()
791799

792800
# Override regular query params with specific ones for this search request
793-
request._request.GET = params
794-
view.request = request
801+
cloned_request._request.GET = params
802+
view.request = cloned_request
795803
view.format_kwarg = 'format'
796804

797805
# Check permissions and update results dict with particular query
798806
model = view.serializer_class.Meta.model
799807

808+
if not check_user_permission(
809+
request.user, model, 'view', groups=groups
810+
):
811+
results[key] = {
812+
'error': _('User does not have permission to view this model')
813+
}
814+
continue
815+
800816
try:
801-
if check_user_permission(request.user, model, 'view'):
802-
results[key] = view.list(request, *args, **kwargs).data
803-
else:
804-
results[key] = {
805-
'error': _(
806-
'User does not have permission to view this model'
807-
)
808-
}
817+
results[key] = view.list(request, *args, **kwargs).data
809818
except Exception as exc:
810819
results[key] = {'error': str(exc)}
811820

src/backend/InvenTree/InvenTree/serializers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from rest_framework.exceptions import ValidationError
2222
from rest_framework.fields import empty
2323
from rest_framework.mixins import ListModelMixin
24+
from rest_framework.permissions import SAFE_METHODS
2425
from rest_framework.serializers import DecimalField
2526
from rest_framework.utils import model_meta
2627
from taggit.serializers import TaggitSerializer, TagListSerializerField
@@ -229,7 +230,7 @@ def do_filtering(self) -> None:
229230
# Skip filtering for a write requests - all fields should be present for data creation
230231
if request := self.context.get('request', None):
231232
if method := getattr(request, 'method', None):
232-
if str(method).lower() in ['post', 'put', 'patch'] and not is_exporting:
233+
if method not in SAFE_METHODS and not is_exporting:
233234
return
234235

235236
# Throw out fields which are not requested (either by default or explicitly)

src/backend/InvenTree/company/serializers.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -268,11 +268,7 @@ class Meta:
268268
source='part', many=False, read_only=True, allow_null=True
269269
),
270270
True,
271-
prefetch_fields=[
272-
Prefetch(
273-
'part', queryset=part.models.Part.objects.select_related('pricing_data')
274-
)
275-
],
271+
prefetch_fields=['part', 'part__pricing_data', 'part__category'],
276272
)
277273

278274
pretty_name = enable_filter(
@@ -438,7 +434,7 @@ def __init__(self, *args, **kwargs):
438434
label=_('Part'), source='part', many=False, read_only=True, allow_null=True
439435
),
440436
False,
441-
prefetch_fields=['part'],
437+
prefetch_fields=['part', 'part__pricing_data'],
442438
)
443439

444440
supplier_detail = enable_filter(

src/backend/InvenTree/company/test_api.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,14 @@
22

33
from django.urls import reverse
44

5-
from company.models import Address, Company, Contact, SupplierPart, SupplierPriceBreak
5+
from company.models import (
6+
Address,
7+
Company,
8+
Contact,
9+
ManufacturerPart,
10+
SupplierPart,
11+
SupplierPriceBreak,
12+
)
613
from InvenTree.unit_test import InvenTreeAPITestCase
714
from part.models import Part
815
from users.permissions import check_user_permission
@@ -498,7 +505,9 @@ def test_manufacturer_part_list(self):
498505

499506
def test_manufacturer_part_detail(self):
500507
"""Tests for the ManufacturerPart detail endpoint."""
501-
url = reverse('api-manufacturer-part-detail', kwargs={'pk': 1})
508+
mp = ManufacturerPart.objects.first()
509+
510+
url = reverse('api-manufacturer-part-detail', kwargs={'pk': mp.pk})
502511

503512
response = self.get(url)
504513
self.assertEqual(response.data['MPN'], 'MPN123')

src/backend/InvenTree/order/api.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,9 @@ class PurchaseOrderOutputOptions(OutputConfiguration):
362362
class PurchaseOrderMixin(SerializerContextMixin):
363363
"""Mixin class for PurchaseOrder endpoints."""
364364

365-
queryset = models.PurchaseOrder.objects.all()
365+
queryset = models.PurchaseOrder.objects.all().prefetch_related(
366+
'supplier', 'created_by'
367+
)
366368
serializer_class = serializers.PurchaseOrderSerializer
367369

368370
def get_queryset(self, *args, **kwargs):
@@ -371,8 +373,6 @@ def get_queryset(self, *args, **kwargs):
371373

372374
queryset = serializers.PurchaseOrderSerializer.annotate_queryset(queryset)
373375

374-
queryset = queryset.prefetch_related('supplier', 'created_by')
375-
376376
return queryset
377377

378378

@@ -824,15 +824,15 @@ def filter_part(self, queryset, name, part):
824824
class SalesOrderMixin(SerializerContextMixin):
825825
"""Mixin class for SalesOrder endpoints."""
826826

827-
queryset = models.SalesOrder.objects.all()
827+
queryset = models.SalesOrder.objects.all().prefetch_related(
828+
'customer', 'created_by'
829+
)
828830
serializer_class = serializers.SalesOrderSerializer
829831

830832
def get_queryset(self, *args, **kwargs):
831833
"""Return annotated queryset for this endpoint."""
832834
queryset = super().get_queryset(*args, **kwargs)
833835

834-
queryset = queryset.prefetch_related('customer', 'created_by')
835-
836836
queryset = serializers.SalesOrderSerializer.annotate_queryset(queryset)
837837

838838
return queryset

src/backend/InvenTree/part/api.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1009,7 +1009,9 @@ class PartMixin(SerializerContextMixin):
10091009
"""Mixin class for Part API endpoints."""
10101010

10111011
serializer_class = part_serializers.PartSerializer
1012-
queryset = Part.objects.all().select_related('pricing_data')
1012+
queryset = (
1013+
Part.objects.all().select_related('pricing_data').prefetch_related('category')
1014+
)
10131015

10141016
starred_parts = None
10151017
is_create = False

src/backend/InvenTree/stock/serializers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -489,14 +489,13 @@ def annotate_queryset(queryset):
489489
),
490490
'parent',
491491
'part__category',
492-
'part__pricing_data',
493492
'supplier_part',
494493
'supplier_part__manufacturer_part',
495494
'customer',
496495
'belongs_to',
497496
'sales_order',
498497
'consumed_by',
499-
).select_related('part')
498+
).select_related('part', 'part__pricing_data')
500499

501500
# Annotate the queryset with the total allocated to sales orders
502501
queryset = queryset.annotate(

src/backend/InvenTree/users/permissions.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,11 @@ def check_user_role(
130130

131131

132132
def check_user_permission(
133-
user: User, model: models.Model, permission: str, allow_inactive: bool = False
133+
user: User,
134+
model: models.Model,
135+
permission: str,
136+
allow_inactive: bool = False,
137+
groups: Optional[QuerySet] = None,
134138
) -> bool:
135139
"""Check if the user has a particular permission against a given model type.
136140
@@ -139,6 +143,7 @@ def check_user_permission(
139143
model: The model class to check (e.g. 'part')
140144
permission: The permission to check (e.g. 'view' / 'delete')
141145
allow_inactive: If False, disallow inactive users from having permissions
146+
groups: Optional cached queryset of groups to check (defaults to user's groups)
142147
143148
Returns:
144149
bool: True if the user has the specified permission
@@ -160,9 +165,11 @@ def check_user_permission(
160165
if table_name in get_ruleset_ignore():
161166
return True
162167

168+
groups = groups or prefetch_rule_sets(user)
169+
163170
for role, table_names in get_ruleset_models().items():
164171
if table_name in table_names:
165-
if check_user_role(user, role, permission):
172+
if check_user_role(user, role, permission, groups=groups):
166173
return True
167174

168175
# Check for children models which inherits from parent role
@@ -172,7 +179,7 @@ def check_user_permission(
172179

173180
if parent_child_string == table_name:
174181
# Check if parent role has change permission
175-
if check_user_role(user, parent, 'change'):
182+
if check_user_role(user, parent, 'change', groups=groups):
176183
return True
177184

178185
# Generate the permission name based on the model and permission

src/performance/tests.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,73 @@ def test_api_options_performance(url):
8888
assert result
8989
assert 'actions' in result
9090
assert len(result['actions']) > 0
91+
92+
93+
@pytest.mark.benchmark
94+
@pytest.mark.parametrize(
95+
'key',
96+
[
97+
'all',
98+
'part',
99+
'partcategory',
100+
'supplierpart',
101+
'manufacturerpart',
102+
'stockitem',
103+
'stocklocation',
104+
'build',
105+
'supplier',
106+
'manufacturer',
107+
'customer',
108+
'purchaseorder',
109+
'salesorder',
110+
'salesordershipment',
111+
'returnorder',
112+
],
113+
)
114+
def test_search_performance(key: str):
115+
"""Benchmark the API search performance."""
116+
SEARCH_URL = '/api/search/'
117+
118+
# An indicative search query for various model types
119+
SEARCH_DATA = {
120+
'part': {'active': True},
121+
'partcategory': {},
122+
'supplierpart': {
123+
'part_detail': True,
124+
'supplier_detail': True,
125+
'manufacturer_detail': True,
126+
},
127+
'manufacturerpart': {
128+
'part_detail': True,
129+
'supplier_detail': True,
130+
'manufacturer_detail': True,
131+
},
132+
'stockitem': {'part_detail': True, 'location_detail': True, 'in_stock': True},
133+
'stocklocation': {},
134+
'build': {'part_detail': True},
135+
'supplier': {},
136+
'manufacturer': {},
137+
'customer': {},
138+
'purchaseorder': {'supplier_detail': True, 'outstanding': True},
139+
'salesorder': {'customer_detail': True, 'outstanding': True},
140+
'salesordershipment': {},
141+
'returnorder': {'customer_detail': True, 'outstanding': True},
142+
}
143+
144+
model_types = list(SEARCH_DATA.keys())
145+
146+
search_params = SEARCH_DATA if key == 'all' else {key: SEARCH_DATA[key]}
147+
148+
# Add in a common search term
149+
search_params.update({'search': '0', 'limit': 50})
150+
151+
response = api_client.post(SEARCH_URL, data=search_params)
152+
assert response
153+
154+
if key == 'all':
155+
for model_type in model_types:
156+
assert model_type in response
157+
assert 'error' not in response[model_type]
158+
else:
159+
assert key in response
160+
assert 'error' not in response[key]

0 commit comments

Comments
 (0)