Skip to content

Commit b0f7eeb

Browse files
authored
Adds delta logic (#144)
* Adds delta logic * Sbumodule update
1 parent a6778f3 commit b0f7eeb

File tree

2 files changed

+29
-13
lines changed

2 files changed

+29
-13
lines changed

controller.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def generate_batches(
4444
attribute_values_raw: List[str],
4545
embedder: Transformer,
4646
attribute_name: str,
47+
for_delta: bool = False,
4748
) -> Iterator[Dict[List[str], List[Any]]]:
4849
length = len(record_ids)
4950
record_batches = []
@@ -61,7 +62,10 @@ def generate_batches(
6162
record_batches.append(record_ids_batch)
6263
document_batches.extend(documents)
6364

64-
embedding_batches = embedder.fit_transform(document_batches, as_generator=True)
65+
if for_delta:
66+
embedding_batches = embedder.transform(document_batches, as_generator=True)
67+
else:
68+
embedding_batches = embedder.fit_transform(document_batches, as_generator=True)
6569
for record_batch in record_batches:
6670
yield {"record_ids": record_batch, "embeddings": next(embedding_batches)}
6771

@@ -193,6 +197,13 @@ def run_encoding(
193197
initial_count = record.count_attribute_list_entries(project_id, attribute_name)
194198
else:
195199
initial_count = record.count(project_id)
200+
201+
is_delta = False
202+
# refinery gateway handles delta logic beforehand so if count is 0 we can be sure it's not a delta
203+
if tensor_count := embedding.get_tensor_count(embedding_id) != 0:
204+
is_delta = True
205+
initial_count -= tensor_count
206+
196207
seed_str = embedding_name
197208
torch.manual_seed(zlib.adler32(bytes(seed_str, "utf-8")))
198209
notification.create(
@@ -214,15 +225,18 @@ def run_encoding(
214225
else:
215226
config_string = model
216227

217-
embedder = get_embedder(
218-
project_id,
219-
embedding_type,
220-
iso2_code,
221-
platform,
222-
model,
223-
api_token,
224-
additional_data,
225-
)
228+
if is_delta:
229+
embedder = __setup_tmp_embedder(project_id, embedding_id)
230+
else:
231+
embedder = get_embedder(
232+
project_id,
233+
embedding_type,
234+
iso2_code,
235+
platform,
236+
model,
237+
api_token,
238+
additional_data,
239+
)
226240

227241
if not embedder:
228242
raise Exception(
@@ -253,7 +267,7 @@ def run_encoding(
253267

254268
try:
255269
record_ids, attribute_values_raw = record.get_attribute_data(
256-
project_id, attribute_name
270+
project_id, attribute_name, is_delta, embedding_id
257271
)
258272
embedding.update_embedding_state_encoding(
259273
project_id,
@@ -279,7 +293,8 @@ def run_encoding(
279293
True,
280294
)
281295
send_project_update(project_id, f"notification_created:{user_id}", True)
282-
embedding.delete_tensors(embedding_id, with_commit=True)
296+
if not is_delta:
297+
embedding.delete_tensors(embedding_id, with_commit=True)
283298
chunk = 0
284299
embedding_canceled = False
285300
for pair in generate_batches(
@@ -289,6 +304,7 @@ def run_encoding(
289304
attribute_values_raw,
290305
embedder,
291306
attribute_name,
307+
for_delta=is_delta,
292308
):
293309
if chunk % 10 == 0:
294310
session_token = general.remove_and_refresh_session(session_token, True)

0 commit comments

Comments
 (0)