11import os
2- import io
3- import requests
4- import numpy as np
5- from PIL import Image
62from fastapi import FastAPI , HTTPException
7- from pydantic import BaseModel
8- from sentence_transformers import SentenceTransformer
9- import torch
10-
11- # --- Configuration ---
12- MODEL_NAME = os .environ .get ("MODEL_NAME" , "clip-ViT-B-32" )
13- # TRANSFORMERS_CACHE is set via environment variable in compose.yml
14- # and defaults to /app/model_cache in the Dockerfile
15- MODEL_CACHE_DIR = os .environ .get ("TRANSFORMERS_CACHE" , "/app/model_cache" )
16-
17- # --- Initialization ---
18- # Initialize the model globally to load it once on startup
19- try :
20- # 1. Check if CUDA (GPU) is available
21- if torch .cuda .is_available ():
22- device = 'cuda'
23- print ("GPU is available. Using GPU." )
24- else :
25- device = 'cpu'
26- print ("GPU not available. Using CPU." )
27- # The model will be downloaded to MODEL_CACHE_DIR if not present
28- model = SentenceTransformer (MODEL_NAME , cache_folder = MODEL_CACHE_DIR , device = device )
29- print (f"Successfully loaded model: { MODEL_NAME } from { MODEL_CACHE_DIR } " )
30- except Exception as e :
31- print (f"Error loading model { MODEL_NAME } : { e } " )
32- # In a real service, you might want to exit or raise an error here
33- model = None
3+ from app .embedding_service import (
4+ embed_text ,
5+ embed_image ,
6+ TextEmbedRequest ,
7+ ImageEmbedRequest ,
8+ EmbeddingResponse ,
9+ default_model ,
10+ MODEL_NAME , OpenAIEmbeddingResponse , OpenAIEmbeddingRequest , open_ai_embed_image
11+ )
3412
# FastAPI application object; the advertised model name comes from the
# embedding service configuration (MODEL_NAME).
app = FastAPI(
    version="1.0.0",
    title="Multimodal Embedding Service",
    description=f"HTTP service for generating text and image embeddings using {MODEL_NAME}.",
)
4018
41- # --- Pydantic Schemas ---
42- class TextEmbedRequest (BaseModel ):
43- texts : list [str ]
44-
45- class ImageEmbedRequest (BaseModel ):
46- image_urls : list [str ]
47-
48- class EmbeddingResponse (BaseModel ):
49- embeddings : list [list [float ]]
50- model : str
51-
52- # --- Utility Functions ---
53- def get_image_from_url (url : str ) -> Image .Image :
54- """Downloads an image from a URL and returns a PIL Image object."""
55- try :
56- response = requests .get (url , stream = True , timeout = 10 )
57- response .raise_for_status ()
58- image = Image .open (io .BytesIO (response .content ))
59- return image
60- except requests .exceptions .RequestException as e :
61- raise HTTPException (status_code = 400 , detail = f"Failed to download image from { url } : { e } " )
62- except Exception as e :
63- raise HTTPException (status_code = 400 , detail = f"Failed to process image from { url } : { e } " )
64-
6519# --- Endpoints ---
6620
@app.get("/health")
async def health_check():
    """Liveness/readiness probe.

    Returns 200 with the configured default model name when the model is
    loaded, and 503 otherwise so orchestrators can hold traffic back.
    """
    if default_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded.")
    # Fixed: the response key previously had a trailing space ("default_model ").
    return {"status": "ok", "default_model": MODEL_NAME}
7226
@app.post("/embed/text", response_model=EmbeddingResponse)
async def embed_text_endpoint(request: TextEmbedRequest):
    """Generates embeddings for a list of text strings.

    Delegates to the embedding service. HTTPExceptions raised by the
    service are propagated unchanged so their status codes survive;
    any other failure becomes a 500 with the error message as detail.
    """
    try:
        return embed_text(request.texts)
    except HTTPException:
        # Don't convert deliberate HTTP errors (e.g. 400/503) into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
8934
@app.post("/embed/image", response_model=EmbeddingResponse)
async def embed_image_endpoint(request: ImageEmbedRequest):
    """Generates embeddings for a list of image URLs.

    Delegates to the embedding service. HTTPExceptions raised by the
    service (e.g. a 400 for an unreachable image URL) are propagated
    unchanged; any other failure becomes a 500 with the error message
    as detail.
    """
    try:
        return embed_image(request.image_urls)
    except HTTPException:
        # Don't convert deliberate HTTP errors into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/v1/embeddings", response_model=OpenAIEmbeddingResponse)
async def openai_embedding_endpoint(request: OpenAIEmbeddingRequest):
    """OpenAI-compatible embeddings endpoint.

    Accepts an OpenAI-style request whose ``input`` field is treated as a
    list of image URLs and whose ``model`` field selects the model; returns
    an OpenAI-shaped embeddings response. (Docstring fixed: the previous
    one was copy-pasted from the plain image endpoint.)
    """
    try:
        return open_ai_embed_image(image_urls=request.input, model_name=request.model)
    except HTTPException:
        # Don't convert deliberate HTTP errors into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
0 commit comments