165 changes: 126 additions & 39 deletions backend/api/endpoints/images.py
@@ -11,6 +11,7 @@
import tempfile
import os
import uuid
from pathlib import Path

from backend.models.images import (
ImageGenerationRequest,
@@ -47,6 +48,83 @@
logger = logging.getLogger(__name__)


def normalize_filename(filename: str) -> str:
"""
Normalize a filename to be safe for file systems.

Args:
filename: The filename to normalize

Returns:
A normalized filename safe for most file systems
"""
if not filename:
return filename

# Use pathlib to handle the filename safely
path = Path(filename)

# Get the stem (filename without extension) and suffix (extension)
stem = path.stem
suffix = path.suffix

# Remove or replace invalid characters for most filesystems
# Keep alphanumeric, hyphens, underscores, and dots
stem = re.sub(r'[^a-zA-Z0-9_\-.]', '_', stem)

# Remove multiple consecutive underscores
stem = re.sub(r'_+', '_', stem)

# Remove leading/trailing underscores and dots
stem = stem.strip('_.')

# Ensure the filename isn't empty
if not stem:
stem = "generated_image"

# Reconstruct the filename
normalized = f"{stem}{suffix}" if suffix else stem

# Ensure the filename isn't too long (most filesystems support 255 chars)
if len(normalized) > 200: # Leave some room for additional suffixes
# Truncate the stem but keep the extension
max_stem_length = 200 - len(suffix)
stem = stem[:max_stem_length]
normalized = f"{stem}{suffix}" if suffix else stem

return normalized
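As an illustrative aside (not part of the diff): with the rules above, a few made-up inputs would normalize roughly as follows.

# Hypothetical usage sketch of normalize_filename; the sample inputs are invented.
print(normalize_filename("A cat on the moon!.png"))   # -> "A_cat_on_the_moon.png"
print(normalize_filename("???"))                       # -> "generated_image"
print(normalize_filename("weird   name..jpg"))         # -> "weird_name.jpg"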


async def generate_filename_for_prompt(prompt: str, extension: str = None) -> str:
"""
Generate a filename using the existing filename generation endpoint.

Args:
prompt: The prompt used for image generation
extension: File extension (e.g., '.png', '.jpg')

Returns:
Generated filename or None if generation fails
"""
try:
# Create request for filename generation
filename_request = ImageFilenameGenerateRequest(
prompt=prompt,
extension=extension
)

# Call the filename generation function directly
filename_response = generate_image_filename(filename_request)

# Normalize the generated filename
generated_filename = normalize_filename(filename_response.filename)

return generated_filename

except Exception as e:
    logger.warning(f"Filename generation failed, falling back to default naming: {e}")
    return None
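A minimal sketch of the intended call pattern for this helper (the prompt, extension, and fallback name are invented; the call is assumed to run where an event loop is available):

import asyncio

async def _filename_demo() -> None:
    # Prompt and extension values are illustrative only.
    name = await generate_filename_for_prompt("a red bicycle at sunset", ".png")
    # None signals failure; callers below fall back to generated_image_<n>.<ext>.
    print(name or "generated_image_1.png")

# asyncio.run(_filename_demo())  # uncomment to run outside an existing event loop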


@router.post("/generate", response_model=ImageGenerationResponse)
async def generate_image(request: ImageGenerationRequest):
"""Generate an image based on the provided prompt and settings"""
Expand Down Expand Up @@ -74,9 +152,6 @@ async def generate_image(request: ImageGenerationRequest):
if request.user:
params["user"] = request.user

logger.info(
f"Generating image with gpt-image-1, quality: {request.quality}, size: {request.size}")

# Generate image
response = dalle_client.generate_image(**params)

@@ -99,10 +174,6 @@ async def generate_image(request: ImageGenerationRequest):
input_tokens_details=input_tokens_details
)

# Log token usage for cost tracking
logger.info(
f"Token usage - Total: {token_usage.total_tokens}, Input: {token_usage.input_tokens}, Output: {token_usage.output_tokens}")

return ImageGenerationResponse(
success=True,
message="Refer to the imgen_model_response for details",
@@ -145,19 +216,12 @@ async def edit_image(request: ImageEditRequest):
if request.user:
params["user"] = request.user

# Log information about multiple images if applicable
# Check if organization is verified when using multiple images
if isinstance(request.image, list):
image_count = len(request.image)
logger.info(
f"Editing with {image_count} reference images using gpt-image-1, quality: {request.quality}, size: {request.size}")

# Check if organization is verified when using multiple images
if image_count > 1 and not settings.OPENAI_ORG_VERIFIED:
logger.warning(
"Using multiple reference images requires organization verification")
else:
logger.info(
f"Editing single image using gpt-image-1, quality: {request.quality}, size: {request.size}")

# Perform image editing
response = dalle_client.edit_image(**params)
@@ -209,10 +273,6 @@ async def edit_image_upload(
):
"""Edit input images uploaded via multipart form data"""
try:
# Log request info
logger.info(
f"Received {len(image)} image(s) for editing with prompt: {prompt}")

# Validate file size for all images
max_file_size_mb = settings.GPT_IMAGE_MAX_FILE_SIZE_MB
temp_files = []
@@ -477,10 +537,32 @@ async def save_generated_images(
# Reset file pointer
img_file.seek(0)

# Create filename
quality_suffix = f"_{request.quality}" if request.model == "gpt-image-1" and hasattr(
request, "quality") else ""
filename = f"generated_image_{idx+1}{quality_suffix}.{img_format.lower()}"
# Generate intelligent filename using the existing endpoint
if request.prompt:
filename = await generate_filename_for_prompt(
request.prompt,
f".{img_format.lower()}"
)

# Add index suffix for multiple images
if filename and len(images_data) > 1:
# Insert index before the extension
path = Path(filename)
stem = path.stem
suffix = path.suffix
filename = f"{stem}_{idx+1}{suffix}"
logger.info(
f"Using generated filename with index: {filename}")
elif filename:
logger.info(f"Using generated filename: {filename}")

# Fallback to default naming if filename generation fails
if not filename:
quality_suffix = f"_{request.quality}" if request.model == "gpt-image-1" and hasattr(
request, "quality") else ""
filename = f"generated_image_{idx+1}{quality_suffix}.{img_format.lower()}"
filename = normalize_filename(filename)
logger.info(f"Using fallback filename: {filename}")

elif "url" in img_data:
# Download image from URL
@@ -507,10 +589,27 @@
# Reset file pointer
img_file.seek(0)

# Create filename
quality_suffix = f"_{request.quality}" if request.model == "gpt-image-1" and hasattr(
request, "quality") else ""
filename = f"generated_image_{idx+1}{quality_suffix}.{ext}"
# Generate intelligent filename using the existing endpoint
if request.prompt:
filename = await generate_filename_for_prompt(
request.prompt,
f".{ext}"
)

# Add index suffix for multiple images
if filename and len(images_data) > 1:
# Insert index before the extension
path = Path(filename)
stem = path.stem
suffix = path.suffix
filename = f"{stem}_{idx+1}{suffix}"

# Fallback to default naming if filename generation fails
if not filename:
quality_suffix = f"_{request.quality}" if request.model == "gpt-image-1" and hasattr(
request, "quality") else ""
filename = f"generated_image_{idx+1}{quality_suffix}.{ext}"
filename = normalize_filename(filename)
else:
logger.warning(
f"Unsupported image data format for image {idx+1}")
@@ -627,7 +726,6 @@ def analyze_image(req: ImageAnalyzeRequest):
file_path += f"?{image_sas_token}"

# Download the image from the URL
logger.info(f"Downloading image for analysis from: {file_path}")
response = requests.get(file_path, timeout=30)
if response.status_code != 200:
raise HTTPException(
@@ -640,7 +738,6 @@

# Option 2: Process from base64 string
elif req.base64_image:
logger.info("Processing image from base64 data")
try:
# Decode base64 to binary
image_content = base64.b64decode(req.base64_image)
@@ -658,8 +755,6 @@
has_transparency = img.mode == 'RGBA' and 'A' in img.getbands()

if has_transparency:
logger.info(
"Image has transparency, converting for analysis")
# Create a white background
background = Image.new(
'RGBA', img.size, (255, 255, 255, 255))
@@ -678,8 +773,6 @@
# This is optional but can help with very large images
width, height = img.size
if width > 1500 or height > 1500:
logger.info(
f"Image is large ({width}x{height}), resizing for analysis")
# Calculate new dimensions
max_dimension = 1500
if width > height:
@@ -713,7 +806,6 @@ def analyze_image(req: ImageAnalyzeRequest):
image_base64 = re.sub(r"^data:image/.+;base64,", "", image_base64)

# analyze the image using the LLM
logger.info("Sending image to LLM for analysis")
image_analyzer = ImageAnalyzer(llm_client, settings.LLM_DEPLOYMENT)
insights = image_analyzer.image_chat(
image_base64, analyze_image_system_message)
@@ -774,17 +866,12 @@ def protect_image_prompt(req: ImagePromptBrandProtectionRequest):
try:
if req.brands_to_protect:
if req.protection_mode == "replace":
logger.info(
f"Replace competitor brands of: {req.brands_to_protect}")
system_message = brand_protect_replace_msg.format(
brands=req.brands_to_protect)
elif req.protection_mode == "neutralize":
logger.info(
f"Neutralize competitor brands of: {req.brands_to_protect}")
system_message = brand_protect_neutralize_msg.format(
brands=req.brands_to_protect)
else:
logger.info(f"No brand protection specified.")
return ImagePromptBrandProtectionResponse(enhanced_prompt=req.original_prompt)

# Ensure LLM client is available
30 changes: 27 additions & 3 deletions backend/api/endpoints/videos.py
@@ -280,12 +280,28 @@ def create_video_generation_with_analysis(req: VideoGenerationWithAnalysisRequest):
from backend.core.azure_storage import AzureBlobStorageService
from azure.storage.blob import ContentSettings

# Generate the final filename for gallery
final_filename = f"{req.prompt.replace(' ', '_')}_{generation_id}.mp4"

# Create Azure storage service
azure_service = AzureBlobStorageService()

# Generate the base filename
base_filename = f"{req.prompt.replace(' ', '_')}_{generation_id}.mp4"

# Extract folder path from request metadata and normalize it
folder_path = req.metadata.get(
'folder') if req.metadata else None
final_filename = base_filename

if folder_path and folder_path != 'root':
# Use Azure service's normalize_folder_path method for consistency
normalized_folder = azure_service.normalize_folder_path(
folder_path)
final_filename = f"{normalized_folder}{base_filename}"
logger.info(
f"Uploading video to folder: {normalized_folder}")
else:
logger.info(
"Uploading video to root directory")

# Upload to Azure Blob Storage
container_client = azure_service.blob_service_client.get_container_client(
"videos")
@@ -305,6 +321,11 @@ def create_video_generation_with_analysis(req: VideoGenerationWithAnalysisRequest):
"upload_date": datetime.now().isoformat()
}

# Add folder path to metadata if specified
if folder_path and folder_path != 'root':
upload_metadata["folder_path"] = azure_service.normalize_folder_path(
folder_path)

# Read the file and upload with metadata
with open(downloaded_path, 'rb') as video_file:
blob_client.upload_blob(
@@ -318,6 +339,9 @@ def create_video_generation_with_analysis(req: VideoGenerationWithAnalysisRequest):
blob_url = blob_client.url
logger.info(
f"Uploaded video to gallery: {blob_url}")
if folder_path and folder_path != 'root':
logger.info(
f"Video uploaded to folder '{folder_path}' with normalized path '{azure_service.normalize_folder_path(folder_path)}'")

except Exception as upload_error:
logger.warning(