Commit 15bb314

Copilot and lstein authored
Free CUDA VRAM after search and indexing operations (#178)
* Initial plan

* Add CUDA memory cleanup after search and batch operations

* Optimize CUDA cleanup: move model to CPU instead of deleting

* Refactor: Extract CUDA cleanup to helper method and simplify code

* Add error handling and fix memory leak in early-return path

* Ensure CUDA cleanup with try-finally block for exception safety

* Delete model references completely to free all CUDA VRAM

  Instead of moving the model to CPU (which left 230 MiB of VRAM allocated), now completely delete the model and tensor references. This ensures VRAM returns to zero after search and indexing operations.

  Changes:
  - Remove the model parameter from _cleanup_cuda_memory()
  - Add torch.cuda.synchronize() before cache clearing
  - Explicitly delete model, preprocess, and all GPU tensors
  - Delete tensors in the finally block with error handling

  Trade-off: the model must be reloaded on each search (~1-2 s overhead), but VRAM is completely freed for other GPU processes.

* Add cleanup for image_tensor, tokens, and combined_embedding tensors

  Ensure all GPU tensors are deleted in the finally block, including:
  - image_tensor (created during preprocessing)
  - tokens (created during tokenization)
  - combined_embedding (intermediate computation tensor)

  This ensures complete VRAM cleanup.

* Improve CUDA cleanup with multiple cache-clear passes

  Add more aggressive CUDA cleanup:
  - Call empty_cache() before and after garbage collection
  - Run gc.collect() twice for more thorough cleanup
  - Document the unavoidable CUDA context baseline (~188 MiB)

  The 188 MiB baseline is the CUDA context initialized by PyTorch on first GPU use. It cannot be freed without ending the Python process; this is a known PyTorch/CUDA limitation.

* Simplify CUDA cleanup sequence per code review

  Streamline cleanup to: synchronize -> empty_cache -> gc.collect -> empty_cache. This is more efficient while still ensuring thorough cleanup. The redundant second gc.collect() call was removed; a single pass is sufficient after the tensor deletions.

Co-authored-by: lstein <[email protected]>
Co-authored-by: copilot-swe-agent[bot] <[email protected]>
Co-authored-by: Lincoln Stein <[email protected]>
1 parent 28995a5 commit 15bb314
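As a sanity check on the "VRAM returns to zero" claim above, one can watch torch.cuda.memory_allocated() around the cleanup sequence. A minimal sketch, not part of this commit, assuming a CUDA-capable machine and the OpenAI clip package:

# Minimal VRAM sanity check (illustrative only, not part of this commit).
# Assumes a CUDA-capable machine and the OpenAI `clip` package.
import gc

import clip
import torch

model, preprocess = clip.load("ViT-B/32", device="cuda")
print(f"after load:    {torch.cuda.memory_allocated() / 2**20:.0f} MiB")

# The commit's cleanup sequence: drop the references first, then
# synchronize -> empty_cache -> gc.collect -> empty_cache.
del model, preprocess
torch.cuda.synchronize()
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()

print(f"after cleanup: {torch.cuda.memory_allocated() / 2**20:.0f} MiB")
# memory_allocated() should read ~0 here; nvidia-smi will still show the
# ~188 MiB CUDA context, which persists until the process exits.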

File tree

1 file changed: +142 −64 lines changed

photomap/backend/embeddings.py

Lines changed: 142 additions & 64 deletions
@@ -9,6 +9,7 @@
 
 import asyncio
 import functools
+import gc
 import logging
 import os
 import sys
@@ -256,6 +257,35 @@ def __init__(self, **data):
         data["embeddings_path"] = Path(data["embeddings_path"]).resolve()
         super().__init__(**data)
 
+    @staticmethod
+    def _cleanup_cuda_memory(device: str) -> None:
+        """
+        Clean up CUDA memory by clearing cache and forcing garbage collection.
+
+        This completely frees GPU VRAM to ensure it returns to zero (or minimal baseline)
+        after operations. The model will need to be reloaded on subsequent operations,
+        but this ensures GPU memory is available for other processes.
+
+        Note: A baseline CUDA context (~188 MiB) may remain after first GPU use.
+        This is a PyTorch/CUDA limitation and cannot be freed without ending the process.
+
+        Args:
+            device: The device string ("cuda" or "cpu")
+        """
+        if device == "cuda":
+            try:
+                # Synchronize to ensure all CUDA operations are complete
+                torch.cuda.synchronize()
+                # Empty the CUDA cache
+                torch.cuda.empty_cache()
+                # Force garbage collection to clean up Python references
+                gc.collect()
+                # Empty cache again after GC to catch any newly freed memory
+                torch.cuda.empty_cache()
+            except RuntimeError as e:
+                # Log but don't crash if CUDA operations fail
+                logger.warning(f"CUDA cleanup failed: {e}")
+
     def get_image_files_from_directory(
         self,
         directory: Path,
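Note that empty_cache() can only reclaim cached blocks whose tensors no longer have live Python references, which is why every call site below does del model, preprocess before invoking the helper. A sketch of the intended call pattern (hypothetical method name; the enclosing class is not shown in this diff):

# Hypothetical call pattern (illustrative; not taken from the diff).
def _encode_batch(self, device: str):
    model, preprocess = clip.load("ViT-B/32", device=device)
    try:
        ...  # run model.encode_image(...) / model.encode_text(...) here
    finally:
        # Drop the references first: empty_cache() can only reclaim
        # blocks whose tensors are no longer referenced anywhere.
        del model, preprocess
        self._cleanup_cuda_memory(device)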
@@ -447,15 +477,22 @@ def _process_images_batch(
         umap_embeddings = self.create_umap_index(
             np.array(embeddings) if embeddings else np.empty((0, 512))
         )
-
-        return IndexResult(
+
+        result = IndexResult(
             embeddings=np.array(embeddings) if embeddings else np.empty((0, 512)),
             filenames=np.array(filenames),
             modification_times=np.array(modification_times),
             metadata=np.array(metadatas, dtype=object),
             umap_embeddings=umap_embeddings,
             bad_files=bad_files,
         )
+
+        # Clean up GPU memory after batch processing
+        # Delete model references to completely free VRAM
+        del model, preprocess
+        self._cleanup_cuda_memory(device)
+
+        return result
 
     async def _process_images_batch_async(
         self, image_paths: list[Path], album_key: str, yield_interval: int = 10
@@ -498,13 +535,20 @@ async def _process_images_batch_async(
             if i % yield_interval == 0:
                 await asyncio.sleep(0.01)
 
-        return IndexResult(
+        result = IndexResult(
             embeddings=np.array(embeddings) if embeddings else np.empty((0, 512)),
             filenames=np.array(filenames),
             modification_times=np.array(modification_times),
             metadata=np.array(metadatas, dtype=object),
             bad_files=bad_files,
         )
+
+        # Clean up GPU memory after async batch processing
+        # Delete model references to completely free VRAM
+        del model, preprocess
+        self._cleanup_cuda_memory(device)
+
+        return result
 
     def _save_embeddings(self, index_result: IndexResult) -> None:
         """Save embeddings to disk and clear cache."""
@@ -983,71 +1027,105 @@ def search_images_by_text_and_image(
             "ViT-B/32", device=device, download_root=self._clip_root()
         )
 
-        # Handle None queries: set weight to zero and skip embedding
-        if query_image_data is None:
-            image_weight = 0.0
-            image_embedding = None
-        else:
-            pil_image = ImageOps.exif_transpose(query_image_data)
-            pil_image = pil_image.convert("RGB")
-            image_tensor: torch.Tensor = preprocess(pil_image)  # type: ignore
-            image_tensor = image_tensor.unsqueeze(0).to(device)
-            with torch.no_grad():
-                image_embedding = model.encode_image(image_tensor).squeeze(0)
-
-        if not positive_query:
-            positive_weight = 0.0
-            pos_emb = None
-        else:
-            tokens = clip.tokenize([positive_query]).to(device)
-            with torch.no_grad():
-                pos_emb = model.encode_text(tokens).squeeze(0)
-
-        if not negative_query:
-            negative_weight = 0.0
-            neg_emb = None
-        else:
-            tokens = clip.tokenize([negative_query]).to(device)
-            with torch.no_grad():
-                neg_emb = model.encode_text(tokens).squeeze(0)
-
-        # If all weights are zero, return empty result
-        if image_weight == 0.0 and positive_weight == 0.0 and negative_weight == 0.0:
-            return [], []
-
-        # Weighted combination: image + positive - negative
-        combined_embedding = None
-        if image_weight > 0.0 and image_embedding is not None:
-            combined_embedding = image_weight * image_embedding
-        if positive_weight > 0.0 and pos_emb is not None:
-            if combined_embedding is None:
-                combined_embedding = positive_weight * pos_emb
-            else:
-                combined_embedding += positive_weight * pos_emb
-        if negative_weight > 0.0 and neg_emb is not None:
-            if combined_embedding is None:
-                combined_embedding = -negative_weight * neg_emb
-            else:
-                combined_embedding -= negative_weight * neg_emb
-
-        # Normalize
-        embeddings_tensor = torch.tensor(embeddings, dtype=torch.float32, device=device)
-        norm_embeddings = F.normalize(embeddings_tensor, dim=-1).to(torch.float32)
-        assert combined_embedding is not None
-        combined_embedding_norm = F.normalize(combined_embedding, dim=-1).to(
-            torch.float32
-        )
-
-        # Similarity
-        similarities = (norm_embeddings @ combined_embedding_norm).cpu().numpy()
-        top_indices = similarities.argsort()[-top_k:][::-1]
-        top_indices = [i for i in top_indices if similarities[i] >= minimum_score]
-        if not top_indices:
-            return [], []
-
-        # Translate from filename array indices to sorted filename top_indices
-        global_indices = [int(filename_map[filenames[i]]) for i in top_indices]
-        return global_indices, similarities[top_indices].tolist()
+        try:
+            # Handle None queries: set weight to zero and skip embedding
+            if query_image_data is None:
+                image_weight = 0.0
+                image_embedding = None
+            else:
+                pil_image = ImageOps.exif_transpose(query_image_data)
+                pil_image = pil_image.convert("RGB")
+                image_tensor: torch.Tensor = preprocess(pil_image)  # type: ignore
+                image_tensor = image_tensor.unsqueeze(0).to(device)
+                with torch.no_grad():
+                    image_embedding = model.encode_image(image_tensor).squeeze(0)
+
+            if not positive_query:
+                positive_weight = 0.0
+                pos_emb = None
+            else:
+                tokens = clip.tokenize([positive_query]).to(device)
+                with torch.no_grad():
+                    pos_emb = model.encode_text(tokens).squeeze(0)
+
+            if not negative_query:
+                negative_weight = 0.0
+                neg_emb = None
+            else:
+                tokens = clip.tokenize([negative_query]).to(device)
+                with torch.no_grad():
+                    neg_emb = model.encode_text(tokens).squeeze(0)
+
+            # If all weights are zero, return empty result
+            if image_weight == 0.0 and positive_weight == 0.0 and negative_weight == 0.0:
+                return [], []
+
+            # Weighted combination: image + positive - negative
+            combined_embedding = None
+            if image_weight > 0.0 and image_embedding is not None:
+                combined_embedding = image_weight * image_embedding
+            if positive_weight > 0.0 and pos_emb is not None:
+                if combined_embedding is None:
+                    combined_embedding = positive_weight * pos_emb
+                else:
+                    combined_embedding += positive_weight * pos_emb
+            if negative_weight > 0.0 and neg_emb is not None:
+                if combined_embedding is None:
+                    combined_embedding = -negative_weight * neg_emb
+                else:
+                    combined_embedding -= negative_weight * neg_emb
+
+            # Normalize
+            embeddings_tensor = torch.tensor(embeddings, dtype=torch.float32, device=device)
+            norm_embeddings = F.normalize(embeddings_tensor, dim=-1).to(torch.float32)
+            assert combined_embedding is not None
+            combined_embedding_norm = F.normalize(combined_embedding, dim=-1).to(
+                torch.float32
+            )
+
+            # Similarity
+            similarities = (norm_embeddings @ combined_embedding_norm).cpu().numpy()
+            top_indices = similarities.argsort()[-top_k:][::-1]
+            top_indices = [i for i in top_indices if similarities[i] >= minimum_score]
+
+            if not top_indices:
+                return [], []
+
+            # Translate from filename array indices to sorted filename top_indices
+            result_indices = [int(filename_map[filenames[i]]) for i in top_indices]
+            result_similarities = similarities[top_indices].tolist()
+
+            return result_indices, result_similarities
+        finally:
+            # Clean up GPU memory after search (always executed)
+            # Delete all GPU tensors and model references to completely free VRAM
+            try:
+                del model, preprocess
+                # Delete any tensors that may still be around
+                if 'image_tensor' in locals():
+                    del image_tensor
+                if 'tokens' in locals():
+                    del tokens
+                if 'embeddings_tensor' in locals():
+                    del embeddings_tensor
+                if 'norm_embeddings' in locals():
+                    del norm_embeddings
+                if 'combined_embedding' in locals():
+                    del combined_embedding
+                if 'combined_embedding_norm' in locals():
+                    del combined_embedding_norm
+                if 'similarities' in locals():
+                    del similarities
+                if 'image_embedding' in locals():
+                    del image_embedding
+                if 'pos_emb' in locals():
+                    del pos_emb
+                if 'neg_emb' in locals():
+                    del neg_emb
+            except (NameError, UnboundLocalError):
+                # Variables may not be defined if early return
+                pass
+            self._cleanup_cuda_memory(device)
 
     def find_duplicate_clusters(self, similarity_threshold=0.995):
         """
