```diff
@@ -9,6 +9,7 @@
 
 import asyncio
 import functools
+import gc
 import logging
 import os
 import sys
```
```diff
@@ -256,6 +257,35 @@ def __init__(self, **data):
         data["embeddings_path"] = Path(data["embeddings_path"]).resolve()
         super().__init__(**data)
 
+    @staticmethod
+    def _cleanup_cuda_memory(device: str) -> None:
+        """
+        Clean up CUDA memory by clearing the cache and forcing garbage collection.
+
+        This frees GPU VRAM so that usage returns to zero (or a minimal baseline)
+        after operations. The model will need to be reloaded on subsequent
+        operations, but GPU memory stays available for other processes meanwhile.
+
+        Note: A baseline CUDA context (~188 MiB) may remain after first GPU use;
+        this is a PyTorch/CUDA limitation and cannot be freed without ending the process.
+
+        Args:
+            device: The device string ("cuda" or "cpu")
+        """
+        if device == "cuda":
+            try:
+                # Synchronize to ensure all CUDA operations are complete
+                torch.cuda.synchronize()
+                # Empty the CUDA cache
+                torch.cuda.empty_cache()
+                # Force garbage collection to clean up Python references
+                gc.collect()
+                # Empty the cache again after GC to catch any newly freed memory
+                torch.cuda.empty_cache()
+            except RuntimeError as e:
+                # Log but don't crash if CUDA operations fail
+                logger.warning(f"CUDA cleanup failed: {e}")
+
     def get_image_files_from_directory(
         self,
         directory: Path,
```
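The cleanup sequence above (synchronize, empty the cache, collect, empty again) can be checked empirically. Below is a minimal, self-contained sketch, assuming a CUDA-capable machine, that uses PyTorch's `torch.cuda.memory_allocated()` and `torch.cuda.memory_reserved()` to watch VRAM fall back toward the baseline context the docstring mentions; the tensor size and printed numbers are illustrative only.

```python
# Sketch: sanity-checking that the _cleanup_cuda_memory sequence releases VRAM.
import gc

import torch


def report(label: str) -> None:
    # memory_allocated: bytes held by live tensors.
    # memory_reserved: bytes held by the caching allocator
    # (what empty_cache() returns to the driver).
    alloc = torch.cuda.memory_allocated() / 2**20
    reserved = torch.cuda.memory_reserved() / 2**20
    print(f"{label}: allocated={alloc:.1f} MiB, reserved={reserved:.1f} MiB")


if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda")  # ~64 MiB of float32
    report("after allocation")

    del x                      # drop the Python reference
    torch.cuda.synchronize()   # wait for pending kernels
    torch.cuda.empty_cache()   # return cached blocks to the driver
    gc.collect()               # collect any lingering references
    torch.cuda.empty_cache()   # catch memory freed by the collection
    report("after cleanup")    # reserved should drop toward the baseline
```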
```diff
@@ -447,15 +477,22 @@ def _process_images_batch(
         umap_embeddings = self.create_umap_index(
             np.array(embeddings) if embeddings else np.empty((0, 512))
         )
-
-        return IndexResult(
+
+        result = IndexResult(
             embeddings=np.array(embeddings) if embeddings else np.empty((0, 512)),
             filenames=np.array(filenames),
             modification_times=np.array(modification_times),
             metadata=np.array(metadatas, dtype=object),
             umap_embeddings=umap_embeddings,
             bad_files=bad_files,
         )
+
+        # Clean up GPU memory after batch processing
+        # Delete model references to completely free VRAM
+        del model, preprocess
+        self._cleanup_cuda_memory(device)
+
+        return result
 
     async def _process_images_batch_async(
         self, image_paths: list[Path], album_key: str, yield_interval: int = 10
```
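Both batch methods now repeat the same load, use, `del`, cleanup sequence. One possible follow-up, sketched below under the assumption that `clip.load` and the cleanup steps behave as in this diff, is a context manager that ties model lifetime to a `with` block so cleanup also runs on exceptions; `loaded_clip` is a hypothetical helper, not part of this commit.

```python
# Hypothetical refactor: scope the CLIP model to a with-block so VRAM is
# released however the block exits. Mirrors the del + cleanup calls above.
import gc
from contextlib import contextmanager

import clip  # openai/CLIP, already used by this module
import torch


@contextmanager
def loaded_clip(device: str, download_root: str | None = None):
    model, preprocess = clip.load(
        "ViT-B/32", device=device, download_root=download_root
    )
    try:
        yield model, preprocess
    finally:
        # Same sequence as _cleanup_cuda_memory in this diff
        del model, preprocess
        if device == "cuda":
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            gc.collect()
            torch.cuda.empty_cache()


# Usage:
# with loaded_clip(device) as (model, preprocess):
#     ...encode images...
# VRAM is released here regardless of how the block exits.
```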
```diff
@@ -498,13 +535,20 @@ async def _process_images_batch_async(
             if i % yield_interval == 0:
                 await asyncio.sleep(0.01)
 
-        return IndexResult(
+        result = IndexResult(
             embeddings=np.array(embeddings) if embeddings else np.empty((0, 512)),
             filenames=np.array(filenames),
             modification_times=np.array(modification_times),
             metadata=np.array(metadatas, dtype=object),
             bad_files=bad_files,
         )
+
+        # Clean up GPU memory after async batch processing
+        # Delete model references to completely free VRAM
+        del model, preprocess
+        self._cleanup_cuda_memory(device)
+
+        return result
 
     def _save_embeddings(self, index_result: IndexResult) -> None:
         """Save embeddings to disk and clear cache."""
```
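The async variant keeps its `await asyncio.sleep(0.01)` every `yield_interval` images, which is what lets the event loop service other coroutines during a long indexing run. A minimal, runnable sketch of that cooperative-yield pattern, with string processing standing in for per-image CLIP encoding:

```python
# Sketch of the cooperative-yield pattern used in _process_images_batch_async:
# a compute-bound loop periodically hands control back to the event loop.
import asyncio


async def process_all(items: list[str], yield_interval: int = 10) -> list[str]:
    results = []
    for i, item in enumerate(items):
        results.append(item.upper())  # stand-in for per-image work
        if i % yield_interval == 0:
            # A short sleep (rather than sleep(0)) also gives timers and
            # IO callbacks a chance to run between chunks of work.
            await asyncio.sleep(0.01)
    return results


if __name__ == "__main__":
    print(asyncio.run(process_all([f"img{i}" for i in range(25)])))
```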
```diff
@@ -983,71 +1027,105 @@ def search_images_by_text_and_image(
             "ViT-B/32", device=device, download_root=self._clip_root()
         )
 
-        # Handle None queries: set weight to zero and skip embedding
-        if query_image_data is None:
-            image_weight = 0.0
-            image_embedding = None
-        else:
-            pil_image = ImageOps.exif_transpose(query_image_data)
-            pil_image = pil_image.convert("RGB")
-            image_tensor: torch.Tensor = preprocess(pil_image)  # type: ignore
-            image_tensor = image_tensor.unsqueeze(0).to(device)
-            with torch.no_grad():
-                image_embedding = model.encode_image(image_tensor).squeeze(0)
-
-        if not positive_query:
-            positive_weight = 0.0
-            pos_emb = None
-        else:
-            tokens = clip.tokenize([positive_query]).to(device)
-            with torch.no_grad():
-                pos_emb = model.encode_text(tokens).squeeze(0)
-
-        if not negative_query:
-            negative_weight = 0.0
-            neg_emb = None
-        else:
-            tokens = clip.tokenize([negative_query]).to(device)
-            with torch.no_grad():
-                neg_emb = model.encode_text(tokens).squeeze(0)
-
-        # If all weights are zero, return empty result
-        if image_weight == 0.0 and positive_weight == 0.0 and negative_weight == 0.0:
-            return [], []
-
-        # Weighted combination: image + positive - negative
-        combined_embedding = None
-        if image_weight > 0.0 and image_embedding is not None:
-            combined_embedding = image_weight * image_embedding
-        if positive_weight > 0.0 and pos_emb is not None:
-            if combined_embedding is None:
-                combined_embedding = positive_weight * pos_emb
+        try:
+            # Handle None queries: set weight to zero and skip embedding
+            if query_image_data is None:
+                image_weight = 0.0
+                image_embedding = None
             else:
-                combined_embedding += positive_weight * pos_emb
-        if negative_weight > 0.0 and neg_emb is not None:
-            if combined_embedding is None:
-                combined_embedding = -negative_weight * neg_emb
+                pil_image = ImageOps.exif_transpose(query_image_data)
+                pil_image = pil_image.convert("RGB")
+                image_tensor: torch.Tensor = preprocess(pil_image)  # type: ignore
+                image_tensor = image_tensor.unsqueeze(0).to(device)
+                with torch.no_grad():
+                    image_embedding = model.encode_image(image_tensor).squeeze(0)
+
+            if not positive_query:
+                positive_weight = 0.0
+                pos_emb = None
             else:
-                combined_embedding -= negative_weight * neg_emb
-
-        # Normalize
-        embeddings_tensor = torch.tensor(embeddings, dtype=torch.float32, device=device)
-        norm_embeddings = F.normalize(embeddings_tensor, dim=-1).to(torch.float32)
-        assert combined_embedding is not None
-        combined_embedding_norm = F.normalize(combined_embedding, dim=-1).to(
-            torch.float32
-        )
+                tokens = clip.tokenize([positive_query]).to(device)
+                with torch.no_grad():
+                    pos_emb = model.encode_text(tokens).squeeze(0)
 
-        # Similarity
-        similarities = (norm_embeddings @ combined_embedding_norm).cpu().numpy()
-        top_indices = similarities.argsort()[-top_k:][::-1]
-        top_indices = [i for i in top_indices if similarities[i] >= minimum_score]
-        if not top_indices:
-            return [], []
+            if not negative_query:
+                negative_weight = 0.0
+                neg_emb = None
+            else:
+                tokens = clip.tokenize([negative_query]).to(device)
+                with torch.no_grad():
+                    neg_emb = model.encode_text(tokens).squeeze(0)
+
+            # If all weights are zero, return empty result
+            if image_weight == 0.0 and positive_weight == 0.0 and negative_weight == 0.0:
+                return [], []
+
+            # Weighted combination: image + positive - negative
+            combined_embedding = None
+            if image_weight > 0.0 and image_embedding is not None:
+                combined_embedding = image_weight * image_embedding
+            if positive_weight > 0.0 and pos_emb is not None:
+                if combined_embedding is None:
+                    combined_embedding = positive_weight * pos_emb
+                else:
+                    combined_embedding += positive_weight * pos_emb
+            if negative_weight > 0.0 and neg_emb is not None:
+                if combined_embedding is None:
+                    combined_embedding = -negative_weight * neg_emb
+                else:
+                    combined_embedding -= negative_weight * neg_emb
+
+            # Normalize
+            embeddings_tensor = torch.tensor(embeddings, dtype=torch.float32, device=device)
+            norm_embeddings = F.normalize(embeddings_tensor, dim=-1).to(torch.float32)
+            assert combined_embedding is not None
+            combined_embedding_norm = F.normalize(combined_embedding, dim=-1).to(
+                torch.float32
+            )
 
-        # Translate from filename array indices to sorted filename top_indices
-        global_indices = [int(filename_map[filenames[i]]) for i in top_indices]
-        return global_indices, similarities[top_indices].tolist()
+            # Similarity
+            similarities = (norm_embeddings @ combined_embedding_norm).cpu().numpy()
+            top_indices = similarities.argsort()[-top_k:][::-1]
+            top_indices = [i for i in top_indices if similarities[i] >= minimum_score]
+
+            if not top_indices:
+                return [], []
+
+            # Translate from filename array indices to sorted filename top_indices
+            result_indices = [int(filename_map[filenames[i]]) for i in top_indices]
+            result_similarities = similarities[top_indices].tolist()
+
+            return result_indices, result_similarities
+        finally:
+            # Clean up GPU memory after search (always executed)
+            # Delete all GPU tensors and model references to completely free VRAM
+            try:
+                del model, preprocess
+                # Delete any tensors that may still be around
+                if 'image_tensor' in locals():
+                    del image_tensor
+                if 'tokens' in locals():
+                    del tokens
+                if 'embeddings_tensor' in locals():
+                    del embeddings_tensor
+                if 'norm_embeddings' in locals():
+                    del norm_embeddings
+                if 'combined_embedding' in locals():
+                    del combined_embedding
+                if 'combined_embedding_norm' in locals():
+                    del combined_embedding_norm
+                if 'similarities' in locals():
+                    del similarities
+                if 'image_embedding' in locals():
+                    del image_embedding
+                if 'pos_emb' in locals():
+                    del pos_emb
+                if 'neg_emb' in locals():
+                    del neg_emb
+            except (NameError, UnboundLocalError):
+                # Variables may not be defined if early return
+                pass
+            self._cleanup_cuda_memory(device)
 
     def find_duplicate_clusters(self, similarity_threshold=0.995):
         """
```
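Note that the search math itself is unchanged by this hunk; only the `try`/`finally` cleanup is new. The query is the weighted sum `image_weight * image + positive_weight * positive - negative_weight * negative`, L2-normalized, and scored against the index by cosine similarity (a dot product of unit vectors). A self-contained toy sketch of that computation, with random embeddings, illustrative weights, and hypothetical variable names; the 512-dimensional size matches CLIP ViT-B/32 as used elsewhere in this file:

```python
# Toy sketch of the weighted-combination search: build a query vector,
# normalize, rank indexed embeddings by cosine similarity, filter by score.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
index = F.normalize(torch.randn(1000, 512), dim=-1)  # indexed image embeddings
img, pos, neg = torch.randn(3, 512)                  # stand-in query embeddings

image_weight, positive_weight, negative_weight = 1.0, 0.5, 0.25
combined = image_weight * img + positive_weight * pos - negative_weight * neg
combined = F.normalize(combined, dim=-1)             # unit vector

similarities = (index @ combined).numpy()            # cosine similarities
top_k, minimum_score = 5, 0.0
top_indices = similarities.argsort()[-top_k:][::-1]  # best match first
top_indices = [i for i in top_indices if similarities[i] >= minimum_score]
print(top_indices, similarities[top_indices].tolist())
```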
|