@@ -38,6 +38,47 @@ def _get_sk_ids_for_type(pointer_type: str) -> tuple[str, str]:
3838 return category_id , type_id
3939
4040
41+ def _get_categories_for_pointer_types (pointer_types : list [str ]) -> set [str ]:
42+ """Return all unique categories for the given pointer types."""
43+ category_set = set ()
44+ for pointer_type in pointer_types :
45+ if pointer_type in TYPE_CATEGORIES :
46+ cats = TYPE_CATEGORIES [pointer_type ]
47+ if isinstance (cats , str ):
48+ category_set .add (cats )
49+ else :
50+ category_set .update (cats )
51+ return category_set
52+
53+
54+ def _build_filter_expressions (
55+ pointer_types , categories , expression_names , expression_values
56+ ):
57+ """Build DynamoDB filter expressions for pointer_types and categories."""
58+ filter_expressions = []
59+ if pointer_types :
60+ expression_names ["#pointer_type" ] = "type"
61+ types_filters = [
62+ f"#pointer_type = :type_{ i } " for i in range (len (pointer_types ))
63+ ]
64+ types_filter_values = {
65+ f":type_{ i } " : pointer_types [i ] for i in range (len (pointer_types ))
66+ }
67+ filter_expressions .append (f"({ ' OR ' .join (types_filters )} )" )
68+ expression_values .update (types_filter_values )
69+ if categories :
70+ expression_names ["#category" ] = "category"
71+ category_filters = [
72+ f"#category = :category_{ i } " for i in range (len (categories ))
73+ ]
74+ category_filter_values = {
75+ f":category_{ i } " : categories [i ] for i in range (len (categories ))
76+ }
77+ filter_expressions .append (f"({ ' OR ' .join (category_filters )} )" )
78+ expression_values .update (category_filter_values )
79+ return filter_expressions
80+
81+
4182class Repository (ABC , Generic [RepositoryModel ]):
4283 ITEM_TYPE : Type [RepositoryModel ]
4384
@@ -163,7 +204,7 @@ def count_by_nhs_number(
163204
164205 if len (pointer_types ) == 1 :
165206 # Optimisation for single pointer type
166- category_id , type_id = _get_sk_ids_for_type (pointer_types [0 ])[ 0 ]
207+ category_id , type_id = _get_sk_ids_for_type (pointer_types [0 ])
167208 patient_sort = f"C#{ category_id } #T#{ type_id } "
168209 key_conditions .append ("begins_with(patient_sort, :patient_sort)" )
169210 expression_values [":patient_sort" ] = patient_sort
@@ -225,7 +266,6 @@ def search(
225266 pointer_types : Optional [List [str ]] = [],
226267 categories : Optional [List [str ]] = [],
227268 ) -> Iterator [DocumentPointer ]:
228- """"""
229269 logger .log (
230270 LogReference .REPOSITORY020 ,
231271 nhs_number = nhs_number ,
@@ -239,67 +279,23 @@ def search(
239279 expression_names = {}
240280 expression_values = {":patient_key" : f"P#{ nhs_number } " }
241281
242- # If both categories and pointer_types are provided, filter on both
282+ # Determine which filters to apply
243283 if pointer_types and categories :
244- expression_names ["#pointer_type" ] = "type"
245- expression_names ["#category" ] = "category"
246- types_filters = [
247- f"#pointer_type = :type_{ i } " for i in range (len (pointer_types ))
248- ]
249- types_filter_values = {
250- f":type_{ i } " : pointer_types [i ] for i in range (len (pointer_types ))
251- }
252- category_filters = [
253- f"#category = :category_{ i } " for i in range (len (categories ))
254- ]
255- category_filter_values = {
256- f":category_{ i } " : categories [i ] for i in range (len (categories ))
257- }
258- filter_expressions .append (f"({ ' OR ' .join (types_filters )} )" )
259- filter_expressions .append (f"({ ' OR ' .join (category_filters )} )" )
260- expression_values .update (types_filter_values )
261- expression_values .update (category_filter_values )
262-
263- # If only pointer_types are provided, retrieve all categories for each type and filter on both
284+ # Use both pointer_types and categories as filters
285+ filter_expressions = _build_filter_expressions (
286+ pointer_types , categories , expression_names , expression_values
287+ )
264288 elif pointer_types and not categories :
265- expression_names ["#pointer_type" ] = "type"
266- expression_names ["#category" ] = "category"
267- types_filters = []
268- category_filters = []
269- types_filter_values = {}
270- category_filter_values = {}
271- category_set = set ()
272- for i , pointer_type in enumerate (pointer_types ):
273- types_filters .append (f"#pointer_type = :type_{ i } " )
274- types_filter_values [f":type_{ i } " ] = pointer_type
275- # Get all categories for this type, handling both set and string
276- if pointer_type in TYPE_CATEGORIES :
277- cats = TYPE_CATEGORIES [pointer_type ]
278- if isinstance (cats , str ):
279- category_set .add (cats )
280- else :
281- category_set .update (cats )
282- for j , cat in enumerate (category_set ):
283- category_filters .append (f"#category = :category_{ j } " )
284- category_filter_values [f":category_{ j } " ] = cat
285- if types_filters :
286- filter_expressions .append (f"({ ' OR ' .join (types_filters )} )" )
287- if category_filters :
288- filter_expressions .append (f"({ ' OR ' .join (category_filters )} )" )
289- expression_values .update (types_filter_values )
290- expression_values .update (category_filter_values )
291-
292- # If only categories are provided, filter on categories
289+ # Get all categories for these pointer_types
290+ all_categories = list (_get_categories_for_pointer_types (pointer_types ))
291+ filter_expressions = _build_filter_expressions (
292+ pointer_types , all_categories , expression_names , expression_values
293+ )
293294 elif categories and not pointer_types :
294- expression_names ["#category" ] = "category"
295- category_filters = [
296- f"#category = :category_{ i } " for i in range (len (categories ))
297- ]
298- category_filter_values = {
299- f":category_{ i } " : categories [i ] for i in range (len (categories ))
300- }
301- filter_expressions .append (f"({ ' OR ' .join (category_filters )} )" )
302- expression_values .update (category_filter_values )
295+ # Only categories provided
296+ filter_expressions = _build_filter_expressions (
297+ [], categories , expression_names , expression_values
298+ )
303299
304300 if custodian :
305301 logger .log (
0 commit comments