 from typing import List, Dict, Any, Optional
+import os
+import copy

 from graphql_api.types import ExtendedSearch
 from submodules.model import Record, Attribute
     user_session,
     embedding,
     attribute,
+    general,
+    tokenization,
+    task_queue,
+    record_label_association,
 )
 from service.search import search
+from submodules.model import enums

+from controller.embedding import connector as embedding_connector
 from controller.record import neural_search_connector
 from controller.embedding import manager as embedding_manager
+from controller.tokenization import tokenization_service
 from util import daemon
+from util.miscellaneous_functions import chunk_list
+import time
+import traceback


 def get_record(project_id: str, record_id: str) -> Record:
@@ -113,3 +124,176 @@ def __reupload_embeddings(project_id: str) -> None: |

 def get_unique_values_by_attributes(project_id: str) -> Dict[str, List[str]]:
     return attribute.get_unique_values_by_attributes(project_id)
+
+
+def edit_records(
+    user_id: str, project_id: str, changes: Dict[str, Any]
+) -> Optional[List[str]]:
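+    """Applies the given value changes to records and keeps labels,
+    tokenization and embeddings in sync. Returns None on success or
+    a list of error messages."""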
+    prepped = __check_and_prep_edit_records(project_id, changes)
+    if "errors_found" in prepped:
+        return prepped["errors_found"]
+
+    records = prepped["records"]
+
+    for key in changes:
+        record = records[changes[key]["recordId"]]
+        # a new object is needed so the change is detected on commit
+        new_data = copy.deepcopy(record.data)
+        if "subKey" in changes[key]:
+            new_data[changes[key]["attributeName"]][changes[key]["subKey"]] = changes[
+                key
+            ]["newValue"]
+        else:
+            new_data[changes[key]["attributeName"]] = changes[key]["newValue"]
+        record.data = new_data
+    general.commit()
+
+    # remove labels that referenced the old attribute values
+    for chunk in chunk_list(prepped["rla_delete_tuples"], 1):
+        record_label_association.delete_by_record_attribute_tuples(project_id, chunk)
+
+    general.commit()
+
+    try:
+        # tokenization currently does a complete rebuild of the docbins of the
+        # touched records; this could be optimized by only rebuilding the changed
+        # record & attribute combinations and re-uploading them
+        tokenization.delete_record_docbins_by_id(project_id, records.keys(), True)
+        tokenization.delete_token_statistics_by_id(project_id, records.keys(), True)
+        tokenization_service.request_tokenize_project(project_id, user_id)
+        time.sleep(1)
+        # wait for tokenization to finish; the endpoint itself handles missing docbins
+        while tokenization.is_doc_bin_creation_running(project_id):
+            time.sleep(0.5)

+    except Exception:
+        __revert_record_data_changes(records, prepped["record_data_backup"])
+        print(traceback.format_exc(), flush=True)
+        return ["tokenization failed"]
+
+    try:
+        embedding_connector.request_re_embed_records(
+            project_id, prepped["embedding_rebuilds"]
+        )

+    except Exception:
+        __revert_record_data_changes(records, prepped["record_data_backup"])
+        print(traceback.format_exc(), flush=True)
+        return ["embedding failed"]
+
+    return None
+
+
+def __revert_record_data_changes(
+    records: Dict[str, Record], data_backup: Dict[str, Any]
+) -> None:
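+    # restore the data snapshots taken in __check_and_prep_edit_records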
+    for record_id in data_backup:
+        records[record_id].data = data_backup[record_id]
+    general.commit()
+
+
+def __check_and_prep_edit_records(
+    project_id: str, changes: Dict[str, Any]
+) -> Dict[str, Any]:
+    # key example: <record_id>@<attribute_name>[@<sub_key>]
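+    # illustrative shape of the changes dict (attribute names and values are made up):
+    # {
+    #     "<record_id>@headline": {
+    #         "recordId": "<record_id>",
+    #         "attributeName": "headline",
+    #         "newValue": "new text",
+    #     },
+    #     "<record_id>@tokens@2": {
+    #         "recordId": "<record_id>",
+    #         "attributeName": "tokens",
+    #         "subKey": 2,
+    #         "newValue": "new token",
+    #     },
+    # }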
+
+    errors_found = []  # list of strings
+    useable_embeddings = {}  # dict of UUID(attribute_id): [embedding_item]
+    attributes = None  # dict of attribute_name: attribute_item
+    records = None  # dict of str(record_id): record_item
+    record_data_backup = None  # dict of str(record_id): record_data
+    embedding_rebuilds = {}  # dict of str(embedding_id): [str(record_id)]
+    record_ids = {changes[key]["recordId"] for key in changes}
+    attribute_names = {changes[key]["attributeName"] for key in changes}
+
+    records = record.get_by_record_ids(project_id, record_ids)
+    if len(record_ids) != len(records):
+        errors_found.append("can't match record ids to project")
+    records = {str(r.id): r for r in records}
+
+    attributes = attribute.get_all_by_names(project_id, attribute_names)
+    if len(attribute_names) != len(attributes):
+        errors_found.append("can't match attributes to project")
+    attributes = {a.name: a for a in attributes}
+
+    tmp = [
+        f"sub_key {changes[key]['subKey']} out of bounds for attribute {changes[key]['attributeName']} of record {changes[key]['recordId']}"
+        for key in changes
+        if "subKey" in changes[key]
+        and changes[key]["subKey"]
+        >= len(records[changes[key]["recordId"]].data[changes[key]["attributeName"]])
+    ]
+    if tmp:
+        errors_found += tmp
+
+    # queued embeddings are not checked here since they haven't run yet and
+    # therefore have nothing that would need rebuilding
+    embeddings = embedding.get_all_by_attribute_ids(
+        project_id, [a.id for a in attributes.values()]
+    )
+    for embedding_item in embeddings:
+        if embedding_item.state == enums.EmbeddingState.FAILED.value:
+            # can be ignored since nothing exists to rebuild yet
+            continue
+
+        if embedding_item.state != enums.EmbeddingState.FINISHED.value:
+            errors_found.append(
+                f"embedding {embedding_item.name} is not finished. Wait for it to finish before editing records."
+            )
+            continue
+
+        emb_path = os.path.join(
+            "/inference", project_id, f"embedder-{str(embedding_item.id)}.pkl"
+        )
+        if not os.path.exists(emb_path):
+            errors_found.append(
+                f"can't find embedding PCA for {embedding_item.name}. Try rebuilding or removing the embeddings on the settings page."
+            )
+            continue
+        if embedding_item.attribute_id not in useable_embeddings:
+            useable_embeddings[embedding_item.attribute_id] = []
+        useable_embeddings[embedding_item.attribute_id].append(embedding_item)
+
+    if tokenization.is_doc_bin_creation_running(project_id):
+        errors_found.append(
+            "tokenization is currently running. Wait for it to finish before editing records."
+        )
+
+    if task_queue.get_by_tokenization(project_id) is not None:
+        errors_found.append(
+            "tokenization is currently queued. Wait for it to finish before editing records."
+        )
+
+    if errors_found:
+        return {"errors_found": errors_found}
+
+    record_data_backup = {str(r.id): copy.deepcopy(r.data) for r in records.values()}
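+    # collect (record_id, attribute_id) tuples for full-value changes of text
+    # attributes; their existing labels reference the old value and are removed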
+    rla_delete_tuples = [
+        (c["recordId"], str(attributes[c["attributeName"]].id))
+        for c in changes.values()
+        if "subKey" not in c
+        and attributes[c["attributeName"]].data_type == enums.DataTypes.TEXT.value
+    ]
+
+    if useable_embeddings:
+        for change in changes.values():
+            attribute_id = attributes[change["attributeName"]].id
+            if attribute_id not in useable_embeddings:
+                continue
+            for embedding_item in useable_embeddings[attribute_id]:
+                embedding_id = str(embedding_item.id)
+                if embedding_id not in embedding_rebuilds:
+                    embedding_rebuilds[embedding_id] = []
+                changed_record_info = {
+                    "record_id": change["recordId"],
+                    "attribute_name": change["attributeName"],
+                }
+                if "subKey" in change:
+                    changed_record_info["sub_key"] = change["subKey"]
+                embedding_rebuilds[embedding_id].append(changed_record_info)
+
+    return {
+        "records": records,
+        "record_data_backup": record_data_backup,
+        "rla_delete_tuples": rla_delete_tuples,
+        "embedding_rebuilds": embedding_rebuilds,
+    }