Skip to content

Commit 8a0f8d2

Browse files
committed
Continue refactoring
1 parent a873b62 commit 8a0f8d2

File tree

2 files changed

+166
-92
lines changed

2 files changed

+166
-92
lines changed

ansible_base/rbac/claims.py

Lines changed: 111 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import hashlib
22
import json
3+
from collections import defaultdict
34
from typing import Type, Union
45

56
from django.apps import apps
@@ -100,82 +101,54 @@ def get_user_object_roles(user: Model) -> list[tuple[str, str, int]]:
100101
return _format_role_assignment_results(assignment_queryset)
101102

102103

103-
def _build_organization_data(org_cls: Type[Model], claims: dict, required_data: dict[str, dict]) -> str:
104-
"""Build organization data for claims processing.
104+
def _load_needed_objects(needed_objects: dict[str, set[str]]) -> dict[str, dict[str, dict]]:
105+
"""Load only the specific objects needed for claims processing.
105106
106107
Args:
107-
org_cls: Organization model class
108-
claims: Claims dictionary to populate
109-
required_data: Required data cache to populate
108+
needed_objects: Dict mapping content_type_model -> set of ansible_ids needed
110109
111110
Returns:
112-
String representing the organization content type model name
111+
Dict mapping content_type_model -> ansible_id -> object_data
113112
"""
114-
org_content_type_model = DABContentType.objects.get_for_model(org_cls).model
115-
required_data[org_content_type_model] = {}
113+
objs_by_ansible_id = {}
116114

117-
# Populate required_data for organizations
118-
for org in org_cls.objects.all().values('id', 'name', 'resource__ansible_id'):
119-
org_id = org['id']
120-
name = org['name']
121-
ansible_id = str(org['resource__ansible_id'])
122-
org_data = {'ansible_id': ansible_id, 'name': name}
115+
# Load organizations if needed
116+
org_content_type_model = DABContentType.objects.get_for_model(get_organization_model()).model
117+
if org_content_type_model in needed_objects:
118+
org_ansible_ids = needed_objects[org_content_type_model]
119+
objs_by_ansible_id[org_content_type_model] = {}
123120

124-
# Store by both id and ansible_id for flexible lookup
125-
required_data[org_content_type_model][org_id] = org_data
126-
required_data[org_content_type_model][ansible_id] = org_data
121+
org_cls = get_organization_model()
122+
for org in org_cls.objects.filter(resource__ansible_id__in=org_ansible_ids).values('name', 'resource__ansible_id'):
123+
ansible_id = str(org['resource__ansible_id'])
124+
objs_by_ansible_id[org_content_type_model][ansible_id] = {'ansible_id': ansible_id, 'name': org['name']}
127125

128-
claims['objects'][org_content_type_model] = []
129-
return org_content_type_model
126+
# Load teams if needed
127+
team_content_type_model = DABContentType.objects.get_for_model(get_team_model()).model
128+
if team_content_type_model in needed_objects:
129+
team_ansible_ids = needed_objects[team_content_type_model]
130+
objs_by_ansible_id[team_content_type_model] = {}
130131

132+
team_cls = get_team_model()
133+
for team in team_cls.objects.filter(resource__ansible_id__in=team_ansible_ids).values(
134+
'name', 'resource__ansible_id', 'organization__resource__ansible_id'
135+
):
136+
ansible_id = str(team['resource__ansible_id'])
137+
org_ansible_id = str(team['organization__resource__ansible_id'])
138+
objs_by_ansible_id[team_content_type_model][ansible_id] = {'ansible_id': ansible_id, 'name': team['name'], 'org': org_ansible_id}
131139

132-
def _build_team_data(team_cls: Type[Model], claims: dict, required_data: dict[str, dict]) -> str:
133-
"""Build team data for claims processing.
134-
135-
Args:
136-
team_cls: Team model class
137-
claims: Claims dictionary to populate
138-
required_data: Required data cache to populate
139-
140-
Returns:
141-
String representing the team content type model name
142-
"""
143-
team_content_type_model = DABContentType.objects.get_for_model(team_cls).model
144-
required_data[team_content_type_model] = {}
145-
146-
# Populate required_data for teams
147-
for team in team_cls.objects.all().values('id', 'name', 'resource__ansible_id', 'organization__resource__ansible_id'):
148-
team_id = team['id']
149-
team_name = team['name']
150-
ansible_id = str(team['resource__ansible_id'])
151-
related_org_ansible_id = str(team['organization__resource__ansible_id'])
152-
team_data = {'ansible_id': ansible_id, 'name': team_name, 'org': related_org_ansible_id}
153-
154-
# Store by both id and ansible_id for flexible lookup
155-
required_data[team_content_type_model][team_id] = team_data
156-
required_data[team_content_type_model][ansible_id] = team_data
157-
158-
claims['objects'][team_content_type_model] = []
159-
return team_content_type_model
140+
return objs_by_ansible_id
160141

161142

162143
def _process_user_object_roles(
163144
user: Model,
164-
org_content_type_model: str,
165-
team_content_type_model: str,
166-
cached_objects_index: dict[str, dict],
167-
cached_content_types: dict[int, str],
168-
required_data: dict[str, dict],
145+
cached_objects_index: defaultdict[str, dict],
169146
) -> tuple[dict[str, list], dict[str, dict[str, Union[str, list[int]]]]]:
170147
"""Process user's object-scoped role assignments and return objects and roles data.
171148
172149
Args:
173150
user: User model instance
174-
org_content_type_model: String name of organization content type model
175-
team_content_type_model: String name of team content type model
176151
cached_objects_index: Cache mapping content_model -> ansible_id -> array_index (will be modified)
177-
cached_content_types: Cache mapping content_type_id -> model_name
178-
required_data: Cache containing object data by content_model and ansible_id
179152
180153
Returns:
181154
Tuple containing:
@@ -188,22 +161,39 @@ def _process_user_object_roles(
188161
{'Organization Admin': {'content_type': 'organization', 'objects': [0]}}
189162
)
190163
"""
164+
# Get content type models for organizations and teams
165+
org_content_type_model = DABContentType.objects.get_for_model(get_organization_model()).model
166+
team_content_type_model = DABContentType.objects.get_for_model(get_team_model()).model
167+
191168
# Initialize objects dict with empty arrays
192169
objects_dict = {org_content_type_model: [], team_content_type_model: []}
193170

194171
user_object_roles = get_user_object_roles(user)
195-
object_roles = {}
196172

173+
# First pass: identify what objects we need
174+
needed_objects = defaultdict(set)
175+
for role_name, ansible_id, content_type_id in user_object_roles:
176+
content_model_type = DABContentType.objects.get_for_id(content_type_id).model
177+
needed_objects[content_model_type].add(ansible_id)
178+
179+
# Load only the objects we actually need
180+
objs_by_ansible_id = _load_needed_objects(needed_objects)
181+
182+
# Second pass: build objects_dict and object_roles
183+
object_roles = {}
197184
for role_name, ansible_id, content_type_id in user_object_roles:
198-
# Get the model for this content_type
199-
content_model_type = cached_content_types[content_type_id]
185+
content_model_type = DABContentType.objects.get_for_id(content_type_id).model
186+
187+
# Ensure the content_model_type exists in objects_dict (in case of new types)
188+
if content_model_type not in objects_dict:
189+
objects_dict[content_model_type] = []
200190

201191
# If the ansible_id is not in the cached_objects_index
202192
if ansible_id not in cached_objects_index[content_model_type]:
203193
# Cache the index (current len will be the next index when we append)
204194
cached_objects_index[content_model_type][ansible_id] = len(objects_dict[content_model_type])
205195
# Add the object to the objects dict
206-
objects_dict[content_model_type].append(required_data[content_model_type][ansible_id])
196+
objects_dict[content_model_type].append(objs_by_ansible_id[content_model_type][ansible_id])
207197

208198
# Get the index value from the cache
209199
object_index = cached_objects_index[content_model_type][ansible_id]
@@ -218,12 +208,34 @@ def _process_user_object_roles(
218208
return objects_dict, object_roles
219209

220210

211+
def _pivot_objects_by_ansible_id(objects_dict: dict[str, list]) -> dict[str, dict[str, dict]]:
212+
"""Convert objects_dict to a lookup dictionary indexed by ansible_id.
213+
214+
Args:
215+
objects_dict: Dictionary with content_model -> list of objects
216+
217+
Returns:
218+
Dictionary mapping content_model -> ansible_id -> object_data
219+
220+
Example:
221+
Input: {'organization': [{'ansible_id': 'uuid1', 'name': 'Org1'}]}
222+
Output: {'organization': {'uuid1': {'ansible_id': 'uuid1', 'name': 'Org1'}}}
223+
"""
224+
objs_by_ansible_id = {}
225+
226+
for content_model_type, objects_list in objects_dict.items():
227+
objs_by_ansible_id[content_model_type] = {}
228+
for obj_data in objects_list:
229+
ansible_id = obj_data['ansible_id']
230+
objs_by_ansible_id[content_model_type][ansible_id] = obj_data
231+
232+
return objs_by_ansible_id
233+
234+
221235
def _fix_team_organization_references(
222236
objects_dict: dict[str, list],
223-
team_content_type_model: str,
224-
org_content_type_model: str,
225-
cached_objects_index: dict[str, dict],
226-
required_data: dict[str, dict],
237+
cached_objects_index: defaultdict[str, dict],
238+
objs_by_ansible_id: dict[str, dict[str, dict]],
227239
) -> None:
228240
"""Convert team organization references from ansible_ids to array indexes.
229241
@@ -232,11 +244,31 @@ def _fix_team_organization_references(
232244
233245
Args:
234246
objects_dict: Dictionary with content_model -> list of objects (will be modified)
235-
team_content_type_model: String name of team content type model
236-
org_content_type_model: String name of organization content type model
237247
cached_objects_index: Cache mapping content_model -> ansible_id -> array_index (will be modified)
238-
required_data: Cache containing object data by content_model and ansible_id
248+
objs_by_ansible_id: Dictionary mapping content_model -> ansible_id -> object_data
239249
"""
250+
# Get content type models for organizations and teams
251+
org_content_type_model = DABContentType.objects.get_for_model(get_organization_model()).model
252+
team_content_type_model = DABContentType.objects.get_for_model(get_team_model()).model
253+
254+
# Only process if there are teams in the objects dict
255+
if team_content_type_model not in objects_dict:
256+
return
257+
258+
# Collect any missing org ansible_ids that we need to load
259+
missing_org_ansible_ids = set()
260+
for team in objects_dict[team_content_type_model]:
261+
org_ansible_id = team['org']
262+
if org_ansible_id not in cached_objects_index[org_content_type_model]:
263+
missing_org_ansible_ids.add(org_ansible_id)
264+
265+
# Load any missing organizations
266+
if missing_org_ansible_ids:
267+
missing_orgs = _load_needed_objects({org_content_type_model: missing_org_ansible_ids})
268+
if org_content_type_model in missing_orgs:
269+
objs_by_ansible_id.setdefault(org_content_type_model, {}).update(missing_orgs[org_content_type_model])
270+
271+
# Now convert ansible_ids to indexes
240272
for team in objects_dict[team_content_type_model]:
241273
org_ansible_id = team['org']
242274

@@ -247,7 +279,7 @@ def _fix_team_organization_references(
247279
# Organization not yet in objects - add it
248280
org_index = len(objects_dict[org_content_type_model])
249281
cached_objects_index[org_content_type_model][org_ansible_id] = org_index
250-
org_data = required_data[org_content_type_model][org_ansible_id]
282+
org_data = objs_by_ansible_id[org_content_type_model][org_ansible_id]
251283
team['org'] = org_index
252284
objects_dict[org_content_type_model].append(org_data)
253285

@@ -298,33 +330,20 @@ def get_user_claims(user: Model) -> dict[str, Union[list[str], dict[str, Union[s
298330
'global_roles': ['Platform Auditor']
299331
}
300332
"""
333+
# Warm the DABContentType cache for efficient lookups
334+
DABContentType.objects.warm_cache()
335+
301336
# Initialize caching dictionaries
302-
cached_objects_index = {} # { <content_model>: {<ansible_id>: <array index integer> } }
303-
cached_content_types = {} # { <content id integer>: <content_model> }
304-
required_data = {} # { <content_model>: { <ansible_id>|<id>: <required_data> } }
305-
306-
# Build content type caches
307-
for content_type in DABContentType.objects.all().values('id', 'model'):
308-
content_type_id = content_type['id']
309-
model = content_type['model']
310-
cached_content_types[content_type_id] = model
311-
cached_objects_index[model] = {}
312-
313-
# Get model classes
314-
org_cls = get_organization_model()
315-
team_cls = get_team_model()
316-
317-
# Build organization and team data caches
318-
org_content_type_model = _build_organization_data(org_cls, {'objects': {}}, required_data)
319-
team_content_type_model = _build_team_data(team_cls, {'objects': {}}, required_data)
320-
321-
# Process user's object role assignments
322-
objects_dict, object_roles = _process_user_object_roles(
323-
user, org_content_type_model, team_content_type_model, cached_objects_index, cached_content_types, required_data
324-
)
337+
cached_objects_index = defaultdict(dict) # { <content_model>: {<ansible_id>: <array index integer> } }
338+
339+
# Process user's object role assignments (loads only needed objects)
340+
objects_dict, object_roles = _process_user_object_roles(user, cached_objects_index)
341+
342+
# Create lookup dictionary by ansible_id for organization reference resolution
343+
objs_by_ansible_id = _pivot_objects_by_ansible_id(objects_dict)
325344

326345
# Convert team organization references from ansible_ids to indexes
327-
_fix_team_organization_references(objects_dict, team_content_type_model, org_content_type_model, cached_objects_index, required_data)
346+
_fix_team_organization_references(objects_dict, cached_objects_index, objs_by_ansible_id)
328347

329348
# Get global roles
330349
global_roles = _get_user_global_roles(user)

test_app/tests/rbac/test_claims.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import pytest
22
from django.contrib.auth import get_user_model
3+
from django.db import connection
4+
from django.test.utils import override_settings
35

46
from ansible_base.rbac import permission_registry
57
from ansible_base.rbac.claims import get_claims_hash, get_user_claims, get_user_claims_hashable_form
@@ -349,3 +351,56 @@ def test_identical_permissions_same_hash(self, claims_scenario, scenario_name):
349351
# Verify hash is a valid SHA-256 hex string (64 characters)
350352
assert len(user1_hash) == 64
351353
assert all(c in '0123456789abcdef' for c in user1_hash)
354+
355+
@override_settings(DEBUG=True)
356+
def test_claims_query_performance_baseline(self, claims_scenario):
357+
"""Performance test to measure database queries for claims generation.
358+
359+
This test establishes a baseline for the number of database queries
360+
used when generating claims for the 'mixed_large' scenario, which is
361+
one of the most complex permission scenarios.
362+
363+
Scenario details:
364+
- Organization Admin for 5 organizations (indexes 0,1,2,3,4)
365+
- Team Member for 10 teams (indexes 0,2,4,6,8,10,12,14,16,18)
366+
- Platform Auditor global role
367+
"""
368+
scenario_name = 'mixed_large'
369+
370+
# Create user and apply the complex scenario
371+
user = get_user_model().objects.create(username='test_user_performance')
372+
claims_scenario.apply_scenario(scenario_name, user)
373+
374+
# Clear any existing queries from setup
375+
connection.queries_log.clear()
376+
377+
# Count queries before claims generation
378+
queries_before = len(connection.queries)
379+
380+
# Generate claims (this is what we're measuring)
381+
user_claims = get_user_claims(user)
382+
383+
# Count queries after claims generation
384+
queries_after = len(connection.queries)
385+
total_queries = queries_after - queries_before
386+
387+
# Verify we got valid claims (basic sanity check)
388+
assert isinstance(user_claims, dict)
389+
assert 'objects' in user_claims
390+
assert 'object_roles' in user_claims
391+
assert 'global_roles' in user_claims
392+
393+
# Report the baseline query count
394+
print("\n=== CLAIMS QUERY PERFORMANCE BASELINE ===")
395+
print(f"Scenario: {scenario_name}")
396+
print(f"Total database queries: {total_queries}")
397+
print("Query details:")
398+
for i, query in enumerate(connection.queries[queries_before:], 1):
399+
print(f" {i}. {query['sql'][:100]}{'...' if len(query['sql']) > 100 else ''}")
400+
print(f" Time: {query['time']}s")
401+
print("=" * 45)
402+
403+
# Assert we maintain our performance baseline
404+
# Baseline updated after optimization: 6 queries for the mixed_large scenario
405+
# (More efficient: only loads objects user has access to, not ALL objects)
406+
assert total_queries == 6, f"Claims generation used {total_queries} queries, expected 6 (baseline)"

0 commit comments

Comments
 (0)