4
4
from functools import partial
5
5
from itertools import chain
6
6
import logging
7
+ from typing import Callable , Iterable , Optional
7
8
8
9
from algoliasearch .http .exceptions import AlgoliaException
9
10
from algoliasearch .search .models .operation_index_params import OperationIndexParams
@@ -57,7 +58,7 @@ class AlgoliaIndex(object):
57
58
tags = None
58
59
59
60
# Use to specify the index to target on Algolia.
60
- index_name = None
61
+ index_name : str | None = None
61
62
62
63
# Use to specify the settings of the index.
63
64
settings = None
@@ -72,9 +73,27 @@ class AlgoliaIndex(object):
72
73
# Name of the attribute to check on instances if should_index is not a callable
73
74
_should_index_is_method = False
74
75
76
+ get_queryset : Optional [Callable [[], Iterable ]] = None
77
+
75
78
def __init__ (self , model , client , settings ):
76
79
"""Initializes the index."""
77
- self .__init_index (model , settings )
80
+ if not self .index_name :
81
+ self .index_name = model .__name__
82
+
83
+ tmp_index_name = "{index_name}_tmp" .format (index_name = self .index_name )
84
+
85
+ if "INDEX_PREFIX" in settings :
86
+ self .index_name = settings ["INDEX_PREFIX" ] + "_" + self .index_name
87
+ tmp_index_name = "{index_prefix}_{tmp_index_name}" .format (
88
+ tmp_index_name = tmp_index_name , index_prefix = settings ["INDEX_PREFIX" ]
89
+ )
90
+ if "INDEX_SUFFIX" in settings :
91
+ self .index_name += "_" + settings ["INDEX_SUFFIX" ]
92
+ tmp_index_name = "{tmp_index_name}_{index_suffix}" .format (
93
+ tmp_index_name = tmp_index_name , index_suffix = settings ["INDEX_SUFFIX" ]
94
+ )
95
+
96
+ self .tmp_index_name = tmp_index_name
78
97
79
98
self .model = model
80
99
self .__client = client
@@ -170,25 +189,6 @@ def __init__(self, model, client, settings):
170
189
)
171
190
)
172
191
173
- def __init_index (self , model , settings ):
174
- if not self .index_name :
175
- self .index_name = model .__name__
176
-
177
- tmp_index_name = "{index_name}_tmp" .format (index_name = self .index_name )
178
-
179
- if "INDEX_PREFIX" in settings :
180
- self .index_name = settings ["INDEX_PREFIX" ] + "_" + self .index_name
181
- tmp_index_name = "{index_prefix}_{tmp_index_name}" .format (
182
- tmp_index_name = tmp_index_name , index_prefix = settings ["INDEX_PREFIX" ]
183
- )
184
- if "INDEX_SUFFIX" in settings :
185
- self .index_name += "_" + settings ["INDEX_SUFFIX" ]
186
- tmp_index_name = "{tmp_index_name}_{index_suffix}" .format (
187
- tmp_index_name = tmp_index_name , index_suffix = settings ["INDEX_SUFFIX" ]
188
- )
189
-
190
- self .tmp_index_name = tmp_index_name
191
-
192
192
@staticmethod
193
193
def _validate_geolocation (geolocation ):
194
194
"""
@@ -239,7 +239,7 @@ def get_raw_record(self, instance, update_fields=None):
239
239
if callable (self .tags ):
240
240
tmp ["_tags" ] = self .tags (instance )
241
241
if not isinstance (tmp ["_tags" ], list ):
242
- tmp ["_tags" ] = list (tmp ["_tags" ])
242
+ tmp ["_tags" ] = list (tmp ["_tags" ]) # pyright: ignore
243
243
244
244
logger .debug ("BUILD %s FROM %s" , tmp ["objectID" ], self .model )
245
245
return tmp
@@ -374,12 +374,7 @@ def update_records(self, qs, batch_size=1000, **kwargs):
374
374
tmp ["objectID" ] = elt
375
375
batch .append (dict (tmp ))
376
376
377
- if len (batch ) >= batch_size :
378
- self .__client .partial_update_objects (
379
- index_name = self .index_name , objects = batch , wait_for_tasks = True
380
- )
381
- batch = []
382
-
377
+ # TODO: pass batch_size to partial_update_objects
383
378
if len (batch ) > 0 :
384
379
self .__client .partial_update_objects (
385
380
index_name = self .index_name , objects = batch , wait_for_tasks = True
@@ -519,9 +514,10 @@ def reindex_all(self, batch_size=1000):
519
514
520
515
counts = 0
521
516
batch = []
517
+ qs = []
522
518
523
- if hasattr (self , "get_queryset" ) and callable (self .get_queryset ): # pyright: ignore
524
- qs = self .get_queryset () # pyright: ignore
519
+ if hasattr (self , "get_queryset" ) and callable (self .get_queryset ):
520
+ qs = self .get_queryset ()
525
521
else :
526
522
qs = self .model .objects .all ()
527
523
@@ -550,7 +546,8 @@ def reindex_all(self, batch_size=1000):
550
546
_resp = self .__client .operation_index (
551
547
self .tmp_index_name ,
552
548
OperationIndexParams (
553
- operation = OperationType .MOVE , destination = self .index_name
549
+ operation = OperationType .MOVE ,
550
+ destination = self .index_name , # pyright: ignore
554
551
),
555
552
)
556
553
self .__client .wait_for_task (self .tmp_index_name , _resp .task_id )
0 commit comments