diff --git a/.gitignore b/.gitignore index e09d756..508a015 100644 --- a/.gitignore +++ b/.gitignore @@ -51,6 +51,9 @@ data/thumbnails/* !data/processed/.gitkeep !data/embeddings/.gitkeep !data/thumbnails/.gitkeep +data--aip/* +data--SephardicStudies/* + # Model files *.pth *.pt diff --git a/PullRequestInformation.md b/PullRequestInformation.md new file mode 100644 index 0000000..acb84d3 --- /dev/null +++ b/PullRequestInformation.md @@ -0,0 +1,130 @@ +## Description + +I made three changes, all specifically to the photograph part of the app. These were: + +1) Allowing the app to run on photographs stored in S3, without having to locally store all of the raw images. +2) Adding a date search filter. +3) Adding an option to filter photographs by file path before running the embedding search. + +## Motivation and Context + +The first of these changes allows the app to scale to larger datasets of photographs. For use cases where there are over a million photos, it will be helpful to be able to run the app without having to store all of the photos locally. + +The next two are to enable more specific photograph searching. This is particularly useful for contexts where a user might know about a specific photo they're looking for, but not know where to find it. By filtering based on date or file name they can get closer to finding the photo they want, and then layer the embedding search on top of that. + +## Type of Change + + + +- [ ] Bug fix (non-breaking change that fixes an issue) +- [x] New feature (non-breaking change that adds functionality) +- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) +- [x] Documentation update +- [ ] Code refactoring (no functional changes) +- [ ] Performance improvement +- [ ] Research contribution (new models, evaluation methods, etc.) 
+- [ ] Other (please describe): + +## Component(s) Affected + + + +- [x] Backend (Python/FastAPI) +- [x] Frontend - Photographs +- [ ] Frontend - Maps +- [ ] Frontend - Documents +- [ ] CLIP/ML models +- [x] Configuration +- [x] Documentation +- [ ] Tests +- [x] Build/deployment + +## Changes Made + + + +- Updated the generate_embeddings script to be able to download files from S3. +- Updated the generate_embeddings script to store the origin date of a photograph into the metadata file. +- Updated the backend to fetch full photographs from S3 when they are not stored locally. +- Updated the backend to provide an API for date search. +- Updated the backend to enable file name filter on text search. +- Updated the photograph frontend to add a date search option. +- Updated the photograph frontend to include a filter bar below the text search, currently only including the file path filter. + +## Testing + +### How Has This Been Tested? + + + +I ran manual tests on each aspect that I described above. + +## Screenshots (if applicable) + + + +| Before | After | +|--------|-------| +| ![Previous text search](image.png) | ![Updated text search](image-1.png) | +| N/A | ![New date search](image-2.png) | + +## Checklist + + + +### Code Quality + +- [x] My code follows the project's coding standards +- [x] I have run `black .` and `isort .` on Python code +- [x] I have run `npm run lint` on frontend code (if applicable) +- [x] I have performed a self-review of my own code +- [x] I have commented my code, particularly in hard-to-understand areas +- [x] My changes generate no new warnings or errors + +### Testing + +- [ ] I have added tests that prove my fix is effective or that my feature works +I didn't see unit tests. +- [ ] New and existing unit tests pass locally with my changes +I didn't see unit tests. 
+- [x] I have tested this locally with actual data + +### Documentation + +- [x] I have updated the documentation accordingly +- [x] I have updated the README if needed +- [x] I have added docstrings to new functions/classes +- [x] I have updated `config.json` documentation if config changes were made + +### Dependencies + +- [x] I have updated `requirements.txt` (if Python dependencies changed) +- [x] I have updated `package.json` (if Node dependencies changed) +- [x] I have documented any new configuration options + +### Research (if applicable) + +- [ ] I have included references to relevant papers or research +- [ ] I have shared evaluation results or benchmarks +- [ ] I have included information about datasets used +- [ ] I have documented model training procedures + +## Breaking Changes + + + + +None / (describe breaking changes) + +## Additional Notes + + + +## Reviewers Checklist (for maintainers) + +- [ ] Code quality and style compliance +- [ ] Test coverage adequate +- [ ] Documentation complete +- [ ] No security concerns +- [ ] Performance implications acceptable +- [ ] Breaking changes documented diff --git a/README.md b/README.md index a3f2d5e..6148f20 100644 --- a/README.md +++ b/README.md @@ -12,9 +12,11 @@ This project describes out Digital Collections Explorer, available at: [https:// We present Digital Collections Explorer, a web-based, open-source exploratory search platform that leverages CLIP (Contrastive Language-Image Pre-training) for enhanced visual discovery of digital collections. Our Digital Collections Explorer can be installed locally and configured to run on a visual collection of interest on disk in just a few steps. Building upon recent advances in multimodal search techniques, our interface enables natural language queries and reverse image searches over digital collections with visual features. An overview of our system can be seen in the image above. 
+We are in the process of adding additional capabilities that are currently only available for photography collections. These include a configuration to run on collections stored in AWS S3 buckets, the option to limit the natural language search to sub-directories of the collection, and the option to perform a search on the original date of the photographs. + ## Features -- Multimodal search capabilities using both text and image inputs +- Multimodal search capabilities using text, image, and metadata inputs (for photographs) - Support for various digital collection types: - Historical maps - Photographs @@ -75,6 +77,12 @@ python -m src.models.clip.generate_embeddings This will process all images found in `raw_data_dir` and create embeddings in `embeddings_dir` (both set in `config.json`). +If your data is stored in an S3 bucket instead of locally, ensure your default AWS profile has read and list access to your bucket, then run the above command with the following arguments: + +```bash +python -m src.models.clip.generate_embeddings --use_remote --bucket <bucket-name> --prefix <prefix> +``` + ### Step 5: Start the Backend Server ```bash @@ -83,6 +91,8 @@ python -m src.backend.main The API server will start at http://localhost:8000 +If your data is stored in S3, change the REMOTE_FILES flag in src.backend.main to True. 
+ ### Customizing the Frontend #### Development Mode diff --git a/image-1.png b/image-1.png new file mode 100644 index 0000000..729bdea Binary files /dev/null and b/image-1.png differ diff --git a/image-2.png b/image-2.png new file mode 100644 index 0000000..cc63740 Binary files /dev/null and b/image-2.png differ diff --git a/image.png b/image.png new file mode 100644 index 0000000..2295567 Binary files /dev/null and b/image.png differ diff --git a/package-lock.json b/package-lock.json index b103052..013fa43 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "digital-collections-explorer", - "version": "0.0.1", + "version": "1.2.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "digital-collections-explorer", - "version": "0.0.1", + "version": "1.2.0", "dependencies": { "chalk": "^4.1.2" } diff --git a/src/backend/api/routes/images.py b/src/backend/api/routes/images.py index 594fe1a..b89ac6c 100644 --- a/src/backend/api/routes/images.py +++ b/src/backend/api/routes/images.py @@ -3,7 +3,11 @@ from fastapi import APIRouter, HTTPException, Query from fastapi.responses import FileResponse +import boto3 +import os + from src.backend.services.embedding_service import embedding_service +import src.backend.utils.helpers as helpers router = APIRouter(tags=["images"]) @@ -37,6 +41,13 @@ async def get_image_by_id( and "processed" in doc["metadata"]["paths"] ): path_str = doc["metadata"]["paths"]["processed"] + if doc["metadata"]["remote"]: + s3_client = boto3.session.Session().client("s3") + local_dir = f"{doc['metadata']['processed_dir']}/{path_str}" + helpers.download_file( + s3_client, doc["metadata"]["bucket"], path_str, local_dir + ) + path_str = local_dir else: raise HTTPException( status_code=404, detail="Image path not found in document metadata" diff --git a/src/backend/api/routes/search.py b/src/backend/api/routes/search.py index 2f6736e..2ea049c 100644 --- a/src/backend/api/routes/search.py +++ 
b/src/backend/api/routes/search.py @@ -1,3 +1,4 @@ +import datetime import logging from io import BytesIO @@ -7,6 +8,7 @@ from ...models.schemas import SearchResponse, SearchResult from ...services.clip_service import clip_service from ...services.embedding_service import embedding_service +from ...services.metadata_search_service import metadata_search_service logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/search", tags=["search"]) @@ -17,6 +19,7 @@ async def search_by_text( query: str, limit: int = Query(30, description="Number of results per page"), page: int = Query(1, description="Page number for pagination"), + filepath_search_term: str = Query("", description="Substring to filter file paths"), ): """Search for similar content using text query.""" offset = (page - 1) * limit @@ -28,7 +31,11 @@ async def search_by_text( text_embedding = clip_service.encode_text(query) logit_scale = clip_service.model.logit_scale.exp().item() raw_results = embedding_service.search( - text_embedding, logit_scale=logit_scale, limit=limit, offset=offset + text_embedding, + logit_scale=logit_scale, + limit=limit, + offset=offset, + filepath_search_term=filepath_search_term, ) search_results = [ @@ -70,3 +77,30 @@ async def search_by_image( except Exception as e: logger.error(f"Error in image search: {str(e)}") return SearchResponse(results=[]) + + +@router.get("/date", response_model=SearchResponse) +async def search_by_date( + query: datetime.date, + limit: int = Query(30, description="Number of results per page"), + page: int = Query(1, description="Page number for pagination"), + searchNearDate: bool = Query( + False, description="Whether to search for dates near the target date" + ), +): + """Search for similar content using date query.""" + offset = (page - 1) * limit + + try: + raw_results = metadata_search_service.date_search( + query, limit=limit, offset=offset, search_near_date=searchNearDate + ) + + search_results = [ + SearchResult(id=result["id"], 
score=1, metadata=result["metadata"]) + for result in raw_results + ] + return SearchResponse(results=search_results) + except Exception as e: + logger.error(f"Error in date search: {str(e)}") + return SearchResponse(results=[]) diff --git a/src/backend/main.py b/src/backend/main.py index 7826ab4..aa7778a 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -11,12 +11,16 @@ from .core.config import settings from .services.embedding_service import embedding_service +import os + logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) +REMOTE_FILES = False + @asynccontextmanager async def lifespan(app): @@ -29,6 +33,10 @@ async def lifespan(app): yield + # Clean up cached files downloaded from S3 + if REMOTE_FILES: + os.system(f"rm -rf {str(settings.processed_data_dir)}/*") + app = FastAPI( title=settings.api_title, diff --git a/src/backend/services/embedding_service.py b/src/backend/services/embedding_service.py index 104786a..87d5767 100644 --- a/src/backend/services/embedding_service.py +++ b/src/backend/services/embedding_service.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional import torch +import numpy as np from ..core.config import settings @@ -99,16 +100,45 @@ def get_document_by_id(self, doc_id: str) -> Optional[Dict[str, Any]]: return None + def filepath_filter(self, filepath_substring: str) -> torch.Tensor: + """Create a metadata filter for file path substring matching""" + metadata_arr = np.array( + [ + self.metadata[item_id].get("paths", {}).get("original", "") + for item_id in self.item_ids + ] + ) + matching_indices = np.where( + np.char.find(metadata_arr.astype(str), filepath_substring) != -1 + )[0] + return torch.tensor(matching_indices, dtype=torch.long) + def search( self, query_embedding: torch.Tensor, logit_scale: Optional[float] = None, limit: int = 20, offset: int = 0, + filepath_search_term: str = "", ) -> List[Dict[str, Any]]: 
"""Search for similar items using query embedding with pagination""" try: - similarities = torch.matmul(self.embeddings, query_embedding.t()).squeeze() + if filepath_search_term != "": + valid_indices = self.filepath_filter(filepath_search_term) + + if len(valid_indices) == 0: + return [] # No items match the filter + + # Filter embeddings to only valid ones + filtered_embeddings = self.embeddings[valid_indices] + similarities = torch.matmul( + filtered_embeddings, query_embedding.t() + ).squeeze() + else: + valid_indices = None + similarities = torch.matmul( + self.embeddings, query_embedding.t() + ).squeeze() if logit_scale is not None: similarities = similarities * logit_scale @@ -127,7 +157,14 @@ def search( for idx, score in zip( paginated_indices.tolist(), paginated_scores.tolist() ): - idx_int = int(idx) + # Map back to original index if we filtered + if valid_indices is not None: + original_idx = valid_indices[idx].item() + else: + original_idx = idx + + idx_int = int(original_idx) + if idx_int >= len(self.item_ids): logger.warning( f"Index {idx_int} out of range for item_ids of length {len(self.item_ids)}" diff --git a/src/backend/services/metadata_search_service.py b/src/backend/services/metadata_search_service.py new file mode 100644 index 0000000..8a6527b --- /dev/null +++ b/src/backend/services/metadata_search_service.py @@ -0,0 +1,71 @@ +from datetime import datetime +import json +import logging +from pathlib import Path + +from ..core.config import settings + +logger = logging.getLogger(__name__) + + +class MetadataSearchService: + def __init__(self): + self.embeddings_dir = Path(settings.embeddings_dir) + self.embeddings = None + self.item_ids = None + self.metadata = None + self.is_loaded = False + + def load_metadata(self) -> None: + """Load metadata from the embeddings directory""" + if self.is_loaded: + logger.info("Metadata already loaded") + return + + try: + metadata_path = self.embeddings_dir / "metadata.json" + + logger.info(f"Looking for 
metadata at {metadata_path}") + + if metadata_path.exists(): + with open(metadata_path, "r") as f: + self.metadata = json.load(f) + logger.info(f"Loaded metadata for {len(self.metadata)} items") + else: + self.metadata = {} + logger.warning("Metadata file not found, proceeding without metadata.") + + except Exception as e: + logger.error(f"Error loading metadata: {str(e)}") + raise + + def date_search(self, target_date, limit=30, offset=0, search_near_date=False): + """Search for items matching the target date""" + if not self.is_loaded: + self.load_metadata() + + results = [] + for item_id, data in self.metadata.items(): + item_date_str = data.get("date") + if item_date_str: + try: + item_date = datetime.strptime( + item_date_str, "%Y-%m-%d %H:%M:%S" + ).date() + if search_near_date: + delta = abs((item_date - target_date).days) + if delta <= 30: # within a month + results.append({"id": item_id, "metadata": data}) + elif item_date == target_date: + results.append({"id": item_id, "metadata": data}) + except ValueError: + logger.error( + f"Failed to parse date for item {item_id}: {item_date_str}" + ) + + # Apply offset and limit + results = results[offset : offset + limit] + return results + + +metadata_search_service = MetadataSearchService() diff --git a/src/backend/utils/helpers.py b/src/backend/utils/helpers.py index dcfcb5f..54022bd 100644 --- a/src/backend/utils/helpers.py +++ b/src/backend/utils/helpers.py @@ -1,6 +1,7 @@ import json import os from pathlib import Path +from typing import Any def load_config(config_path=None): @@ -25,3 +26,15 @@ def load_config(config_path=None): config[key] = str(root_dir / config[key]) return config + + +def download_file(client: Any, bucket: str, filename: str, destination: str): + # Get the directory path (everything except the filename) + directory = os.path.dirname(destination) + + # Create the directory structure if it doesn't exist + # exist_ok=True means no error if directory already exists + if directory: # Only 
create if there's actually a directory path + os.makedirs(directory, exist_ok=True) + + client.download_file(bucket, filename, destination) diff --git a/src/frontend/photographs/src/App.jsx b/src/frontend/photographs/src/App.jsx index dd1885a..5e96abb 100644 --- a/src/frontend/photographs/src/App.jsx +++ b/src/frontend/photographs/src/App.jsx @@ -2,20 +2,22 @@ import React, { useState, useEffect, useCallback, useRef } from 'react'; import SearchBar from './components/SearchBar'; import { ResultsPerPageDropdown } from './components/Pagination'; import SearchResults from './components/SearchResults'; -import { searchByText, searchByImage, getEmbeddingStats } from './services/api'; +import { searchByText, searchByImage, getEmbeddingStats, searchByDate } from './services/api'; import './App.css'; function App() { const [photos, setPhotos] = useState([]); const [searchQuery, setSearchQuery] = useState(''); + const [filepathSearchTerm, setFilepathSearchTerm] = useState(''); const [isLoading, setIsLoading] = useState(false); const [error, setError] = useState(null); const searchInputRef = useRef(null); - const [searchMode, setSearchMode] = useState('text'); // 'text' or 'image' + const [searchMode, setSearchMode] = useState('text'); // 'text', 'image', 'date' const [uploadedImage, setUploadedImage] = useState(null); const [hasMore, setHasMore] = useState(false); const [currentPage, setCurrentPage] = useState(1); const [resultsPerPage, setResultsPerPage] = useState(50); + const [searchNearDate, setSearchNearDate] = useState(false); const [embeddingCount, setEmbeddingCount] = useState(null); useEffect(() => { @@ -51,7 +53,7 @@ function App() { })); }; - const handleSearchByText = useCallback(async (query) => { + const handleSearchByText = useCallback(async (query, filepathSearchTerm) => { if (!query.trim()) { setError('Please enter a search term'); return; @@ -61,9 +63,10 @@ function App() { setError(null); setSearchMode('text'); setSearchQuery(query); + 
setFilepathSearchTerm(filepathSearchTerm); try { - const results = await searchByText(query, resultsPerPage, currentPage); + const results = await searchByText(query, resultsPerPage, currentPage, filepathSearchTerm); setPhotos(formatPhotosForGallery(results)); setHasMore(results.length >= resultsPerPage); } catch (error) { @@ -91,13 +94,39 @@ function App() { } }, [resultsPerPage, currentPage]); + const handleSearchByDate = useCallback(async (query, searchNearDate) => { + if (!query.trim()) { + setError('Please enter a search term'); + return; + } + + setIsLoading(true); + setError(null); + setSearchMode('date'); + setSearchQuery(query); + setSearchNearDate(searchNearDate); + + try { + const results = await searchByDate(query, resultsPerPage, currentPage, searchNearDate); + setPhotos(formatPhotosForGallery(results)); + setHasMore(results.length >= resultsPerPage); + } catch (error) { + console.error('Error performing search:', error); + setError('Date search failed. Please try again.'); + } finally { + setIsLoading(false); + } +}, [resultsPerPage, currentPage]); + useEffect(() => { if (searchMode === 'text' && searchQuery.trim()) { - handleSearchByText(searchQuery); + handleSearchByText(searchQuery, filepathSearchTerm); } else if (searchMode === 'image' && uploadedImage) { handleSearchByImage(uploadedImage); + } else if (searchMode === 'date' && searchQuery.trim()) { + handleSearchByDate(searchQuery, searchNearDate); } - }, [currentPage, resultsPerPage, searchMode, searchQuery, uploadedImage, handleSearchByText, handleSearchByImage]); + }, [currentPage, resultsPerPage, searchMode, searchQuery, searchNearDate, uploadedImage, filepathSearchTerm, handleSearchByText, handleSearchByImage, handleSearchByDate]); const handleSearchModeChanged = useCallback((mode) => { setSearchMode(mode); @@ -121,10 +150,15 @@ function App() { setSearchMode={handleSearchModeChanged} searchQuery={searchQuery} setSearchQuery={setSearchQuery} + searchNearDate={searchNearDate} + 
setSearchNearDate={setSearchNearDate} uploadedImage={uploadedImage} setUploadedImage={setUploadedImage} + filepathSearchTerm={filepathSearchTerm} + setFilepathSearchTerm={setFilepathSearchTerm} onSearchByText={handleSearchByText} onSearchByImage={handleSearchByImage} + onSearchByDate={handleSearchByDate} /> + <> +
Limit to file paths containing:
+ setFilepathSearchTerm(e.target.value)} + /> + + + ) +} + +export default FilterBar; \ No newline at end of file diff --git a/src/frontend/photographs/src/components/SearchBar.css b/src/frontend/photographs/src/components/SearchBar.css index 85b75c3..a5c9fd6 100644 --- a/src/frontend/photographs/src/components/SearchBar.css +++ b/src/frontend/photographs/src/components/SearchBar.css @@ -19,6 +19,7 @@ transition: border-color 0.3s; background-color: #2c2c2c; color: #f8f5f0; + width: 90%; } .search-input::placeholder { @@ -193,6 +194,33 @@ color: #fff; } +.date-search-block { + display: flex; + flex-direction: column; + width: 100%; + align-items: flex-start; + justify-content: center; + gap: 12px; +} + +.near-date-checkbox-container { + display: flex; + flex-direction: row; + align-items: flex-start; + justify-content: flex-start; +} + +.include-near-date { + display: flex; + align-items: center; + justify-content: center; + accent-color: #8b5a2b; + font-size: 2rem; + margin: 5; + width: 20px; + height: 20px; +} + @media (max-width: 600px) { .search-bar form { flex-direction: column; @@ -203,3 +231,10 @@ padding: 12px; } } + +.vertical-stack { + display: flex; + flex-direction: column; + gap: 10px; + width: 100%; +} diff --git a/src/frontend/photographs/src/components/SearchBar.jsx b/src/frontend/photographs/src/components/SearchBar.jsx index 02ea3d4..c6ac32b 100644 --- a/src/frontend/photographs/src/components/SearchBar.jsx +++ b/src/frontend/photographs/src/components/SearchBar.jsx @@ -1,5 +1,6 @@ import React, { useState } from 'react'; import './SearchBar.css'; +import FilterBar from './FilterBar'; function SearchBar({ inputRef, @@ -9,8 +10,13 @@ function SearchBar({ setSearchQuery, uploadedImage, setUploadedImage, + filepathSearchTerm, + setFilepathSearchTerm, onSearchByText, onSearchByImage, + onSearchByDate, + searchNearDate, + setSearchNearDate, }) { const [previewUrl, setPreviewUrl] = useState(null); @@ -18,9 +24,11 @@ function SearchBar({ 
e.preventDefault(); if (searchMode === 'text') { - onSearchByText(searchQuery); + onSearchByText(searchQuery, filepathSearchTerm); } else if (searchMode === 'image' && uploadedImage) { onSearchByImage(uploadedImage); + } else if (searchMode === 'date') { + onSearchByDate(searchQuery, searchNearDate); } }; @@ -48,13 +56,14 @@ function SearchBar({ }; const switchMode = (mode) => { + if (mode === searchMode) return; setSearchMode(mode); - if (mode === 'text') { + if (mode !== 'image') { clearImage(); - } else { - setSearchQuery(''); - } + } + setSearchQuery(''); + }; return ( @@ -74,26 +83,40 @@ function SearchBar({ > Image Search +
{searchMode === 'text' ? ( - <> - setSearchQuery(e.target.value)} - placeholder="Search historical photographs..." - className="search-input" - aria-label="Search photographs" +
+
+ setSearchQuery(e.target.value)} + placeholder="Search historical photographs..." + className="search-input" + aria-label="Search photographs" + /> + +
+ - - - ) : ( +
+ + ) : (searchMode === 'image' ? (
{!previewUrl ? (
@@ -130,7 +153,36 @@ function SearchBar({ Find Similar
- )} + ) : ( + <> +
+ setSearchQuery(e.target.value)} + placeholder="Search historical photographs..." + className="search-input" + aria-label="Search photographs" + /> +
+ setSearchNearDate(e.target.checked)} + id="near-date-checkbox" + className="include-near-date" + aria-label="Include photographs taken near this date" + /> + +
+
+ + + ))} {searchMode === 'text' && ( diff --git a/src/frontend/photographs/src/services/api.js b/src/frontend/photographs/src/services/api.js index 388be29..f170a77 100644 --- a/src/frontend/photographs/src/services/api.js +++ b/src/frontend/photographs/src/services/api.js @@ -5,12 +5,13 @@ const API_URL = import.meta.env.API_BASE_URL; * @param {string} query - The text query * @param {number} limit - Maximum number of results to return (default: 50) * @param {number} page - Page number for pagination (default: 1) + * @param {string} filepathSearchTerm - Substring to filter file paths (default: '') * @returns {Promise} - Array of search results */ -export const searchByText = async (query, limit = 50, page = 1) => { +export const searchByText = async (query, limit = 50, page = 1, filepathSearchTerm = '') => { try { const pageParam = Math.max(1, parseInt(page) || 1); - const response = await fetch(`${API_URL}/api/search/text?query=${encodeURIComponent(query)}&limit=${limit}&page=${pageParam}`); + const response = await fetch(`${API_URL}/api/search/text?query=${encodeURIComponent(query)}&limit=${limit}&page=${pageParam}&filepath_search_term=${encodeURIComponent(filepathSearchTerm)}`); if (!response.ok) { throw new Error(`API error: ${response.status}`); @@ -74,3 +75,28 @@ export const getEmbeddingStats = async () => { throw error; } }; + +/** + * Search for similar photographs by date + * @param {string} query - The date query + * @param {number} limit - Maximum number of results to return (default: 50) + * @param {number} page - Page number for pagination (default: 1) + * @returns {Promise} - Array of search results + */ +export const searchByDate = async (query, limit = 50, page = 1, searchNearDate = false) => { + try { + const pageParam = Math.max(1, parseInt(page) || 1); + console.log('searchByDate called with searchNearDate:', searchNearDate); + const response = await 
fetch(`${API_URL}/api/search/date?query=${encodeURIComponent(query)}&limit=${limit}&page=${pageParam}&searchNearDate=${searchNearDate}`); + + if (!response.ok) { + throw new Error(`API error: ${response.status}`); + } + + const { results } = await response.json(); + return results; + } catch (error) { + console.error('Error searching photos:', error); + throw error; + } +}; \ No newline at end of file diff --git a/src/models/clip/generate_embeddings.py b/src/models/clip/generate_embeddings.py index 28a5711..f04ec01 100644 --- a/src/models/clip/generate_embeddings.py +++ b/src/models/clip/generate_embeddings.py @@ -11,9 +11,19 @@ import torch from pdf2image import convert_from_path from PIL import Image +from dataclasses import dataclass +from typing import Dict, Any +import PyPDF2 +import base64 +import argparse +from PIL import Image +from datetime import datetime + + from transformers import CLIPModel, CLIPProcessor from src.backend.core.config import settings +import src.models.clip.util as util logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" @@ -63,6 +73,9 @@ def process_pdf( processed_dir, thumbnails_dir, timing_info: Dict[str, float], + use_remote=False, + bucket="", + file="", ): """Process a PDF file and generate embeddings for each page""" try: @@ -86,6 +99,13 @@ def process_pdf( with open(file_path, "rb") as f: pdf = PyPDF2.PdfReader(f) n_pages = len(pdf.pages) + metadata = pdf.metadata + if metadata and "/CreationDate" in metadata: + pdf_date = datetime.strptime( + metadata["/CreationDate"], "%Y:%m:%d %H:%M:%S" + ) + else: + pdf_date = None except Exception as e: logger.error(f"Error reading PDF metadata: {file_path}, error: {e}") return None @@ -105,10 +125,11 @@ def process_pdf( thumbnail_path = pdf_thumbnails_dir / f"{i}.jpg" thumbnail.save(thumbnail_path, "JPEG", quality=80) - processed_image = image.copy() - processed_image.thumbnail((1920, 1920), Image.Resampling.LANCZOS) - processed_image_path = 
pdf_processed_dir / f"{i}.jpg" - processed_image.save(processed_image_path, "JPEG", quality=90) + if not use_remote: + processed_image = image.copy() + processed_image.thumbnail((1920, 1920), Image.Resampling.LANCZOS) + processed_image_path = pdf_processed_dir / f"{i}.jpg" + processed_image.save(processed_image_path, "JPEG", quality=90) page_embedding = generate_embeddings( model, processor, [image], device, timing_info @@ -142,11 +163,19 @@ def process_pdf( "type": "pdf_page", "page": i, "n_pages": n_pages, + "date": pdf_date.strftime("%Y-%m-%d %H:%M:%S") if pdf_date else None, "paths": { "original": str(file_path), - "processed": str(pdf_processed_dir / f"{i}.jpg"), + "processed": ( + str(pdf_processed_dir / f"{i}.jpg") + if not use_remote + else f"{file}" + ), "thumbnail": str(pdf_thumbnails_dir / f"{i}.jpg"), }, + "remote": use_remote, + "bucket": bucket, + "processed_dir": str(processed_dir), } results.append((item_id, embedding, metadata)) @@ -165,6 +194,9 @@ def process_image( processed_dir, thumbnails_dir, timing_info: Dict[str, float], + use_remote=False, + bucket="", + file="", ): """Process an image file and generate its embedding""" try: @@ -176,13 +208,14 @@ def process_image( thumbnail_path = thumbnails_dir / f"{file_path.stem}.jpg" thumbnail.save(thumbnail_path, "JPEG", quality=80) - processed_image = image.copy() - processed_image.thumbnail((1920, 1920), Image.Resampling.LANCZOS) + if not use_remote: + processed_image = image.copy() + processed_image.thumbnail((1920, 1920), Image.Resampling.LANCZOS) - image_processed_dir = processed_dir / file_path.stem - image_processed_dir.mkdir(parents=True, exist_ok=True) - processed_image_path = image_processed_dir / "0.jpg" - processed_image.save(processed_image_path, "JPEG", quality=90) + image_processed_dir = processed_dir / file_path.stem + image_processed_dir.mkdir(parents=True, exist_ok=True) + processed_image_path = image_processed_dir / "0.jpg" + processed_image.save(processed_image_path, "JPEG", 
def process_remote_files(
    model: Any,
    processor: Any,
    device: str,
    raw_data_dir: Path,
    processed_dir: Path,
    thumbnails_dir: Path,
    bucket: str,
    prefix: str,
    timing_info: Dict[str, float],
) -> "ProcessingResult":
    """Process all supported files stored under an S3 bucket/prefix.

    Each eligible object is downloaded into ``raw_data_dir``, embedded via
    ``process_pdf`` / ``process_image`` in remote mode, and the local copy is
    deleted afterwards so only thumbnails and metadata remain on disk.

    Args:
        model/processor/device: the CLIP model stack used for embeddings.
        raw_data_dir: scratch directory for temporary downloads.
        processed_dir/thumbnails_dir: output directories for derived images.
        bucket/prefix: S3 location to enumerate.
        timing_info: mutable dict accumulating embedding durations.

    Returns:
        ProcessingResult with embeddings, metadata, and skip/failure counts.
    """
    logger.info(f"Looking for files in {bucket}/{prefix}")

    client = util.create_client()

    # NOTE(review): the original dropped the first listed key with [1:],
    # presumably to skip the prefix "folder" object itself — but
    # util.get_file_names already filters keys ending in "/". TODO confirm
    # whether the slice still drops a real file on some buckets.
    keys = util.get_file_names(client, bucket, prefix)[1:]

    supported_extensions = {
        "pdf": ["pdf"],
        "image": ["jpg", "jpeg", "png", "gif", "bmp", "tiff", "webp"],
    }

    files_to_process = []  # list of (s3_key, Path) pairs
    skipped_items_count = 0
    failed_items_count = 0

    for key in keys:
        file_path = Path(key)
        ext = file_path.suffix.lower().lstrip(".")
        is_image = ext in supported_extensions["image"]
        is_pdf = ext in supported_extensions["pdf"]

        # Images are always eligible; PDFs only outside the photographs
        # collection. Parentheses added — the original relied on and/or
        # precedence, which happens to match but is easy to misread.
        if is_image or (is_pdf and settings.collection_type != "photographs"):
            files_to_process.append((key, file_path))
        else:
            skipped_items_count += 1

    logger.info(f"Found {len(files_to_process)} eligible files to process.")

    if not files_to_process:
        logger.warning(f"No files found in {bucket}")
        return ProcessingResult({}, {}, skipped_items_count, failed_items_count)

    embeddings = {}
    metadata = {}

    # BUG FIX: the original loop iterated over *all* keys (enumerate of
    # file_paths), re-downloading files it had just skipped, and left
    # `results` unbound for unsupported extensions (a NameError swallowed
    # by the broad except and miscounted as a failure). Iterate only the
    # eligible (key, path) pairs instead.
    for key, file_path in files_to_process:
        try:
            ext = file_path.suffix.lower().lstrip(".")
            is_pdf = ext in supported_extensions["pdf"]

            (raw_data_dir / prefix).mkdir(parents=True, exist_ok=True)
            local_path = f"{raw_data_dir}/{key}"
            util.download_file(client, bucket, key, local_path)

            if is_pdf:
                results = process_pdf(
                    local_path,
                    raw_data_dir,
                    model,
                    processor,
                    device,
                    processed_dir,
                    thumbnails_dir,
                    timing_info,
                    True,
                    bucket,
                    key,
                )
            else:
                results = process_image(
                    local_path,
                    raw_data_dir,
                    model,
                    processor,
                    device,
                    processed_dir,
                    thumbnails_dir,
                    timing_info,
                    True,
                    bucket,
                    key,
                )

            # Remove the temporary download once it has been embedded.
            (raw_data_dir / file_path).unlink()

            if results:
                for item_id, embedding, metadata_item in results:
                    embeddings[item_id] = embedding
                    metadata[item_id] = metadata_item
            else:
                failed_items_count += 1
        except Exception as e:
            logger.error(f"Error processing {file_path}: {e}")
            failed_items_count += 1

    return ProcessingResult(
        embeddings, metadata, skipped_items_count, failed_items_count
    )


def parse_args(argv=None):
    """Parse the CLI flags controlling remote (S3) embedding generation.

    Args:
        argv: optional explicit argument list; defaults to ``sys.argv``.
            Added (backward-compatibly) so the parser can be exercised in
            tests without mutating process-global state.

    Returns:
        Tuple of (use_remote, bucket, prefix).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use_remote",
        action="store_true",
        help="Whether to use a remote set of inputs.",
    )
    parser.add_argument(
        "--bucket",
        type=str,
        default="",
        help="The name of the bucket where the images are stored",
    )
    parser.add_argument(
        "--prefix", type=str, default="", help="The prefix where the images are stored"
    )

    args = parser.parse_args(argv)
    return args.use_remote, args.bucket, args.prefix
"""Small S3 helpers used by the embedding-generation script."""

import os
from typing import Any, List


def create_client() -> Any:
    """Return a boto3 S3 client bound to a fresh session."""
    # Imported lazily so this module can still be imported (and
    # get_file_names/download_file used with stub clients in tests) on
    # machines without boto3 installed.
    import boto3

    return boto3.session.Session().client("s3")


def get_file_names(client: Any, bucket: str, prefix: str) -> List[str]:
    """List every object key under ``bucket``/``prefix``.

    Keys ending in "/" (S3 "folder" placeholder objects) are skipped.
    Uses the list_objects_v2 paginator so buckets with more than 1000
    matching objects are fully enumerated.
    """
    paginator = client.get_paginator("list_objects_v2")
    pages = paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter="/")

    filenames: List[str] = []
    for page in pages:
        # BUG FIX: a page with no matching objects carries no "Contents"
        # key; the original page["Contents"] raised KeyError for empty
        # prefixes/buckets. Default to an empty list instead.
        for obj in page.get("Contents", []):
            if not obj["Key"].endswith("/"):
                filenames.append(obj["Key"])
    return filenames


def download_file(client: Any, bucket: str, filename: str, destination: str) -> None:
    """Download ``bucket``/``filename`` to the local path ``destination``.

    Any missing parent directories of ``destination`` are created first
    (exist_ok, so an already-present directory is not an error).
    """
    directory = os.path.dirname(destination)
    if directory:  # only create when a directory component is present
        os.makedirs(directory, exist_ok=True)

    client.download_file(bucket, filename, destination)