Skip to content
This repository was archived by the owner on Jun 12, 2024. It is now read-only.

Commit dd805eb

Browse files
committed
fix: gemini image
1 parent 24ce19f commit dd805eb

File tree

1 file changed

+42
-94
lines changed

1 file changed

+42
-94
lines changed

gemini/src/model/image.py

Lines changed: 42 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1+
import os
12
import random
23
import httpx
34
import asyncio
45
import datetime
5-
import os
66
from pathlib import Path
77
from typing import List, Optional, Dict
88
from loguru import logger
@@ -14,121 +14,69 @@ class GeminiImage(BaseModel):
1414
title: str = "[Image]"
1515
alt: str = ""
1616

17-
@staticmethod
18-
def validate_images(images):
19-
if images == []:
17+
@classmethod
18+
def validate_images(cls, images):
19+
if not images:
2020
raise ValueError("Input is empty. Please provide images to proceed.")
2121

2222
@staticmethod
23-
async def fetch_bytes(
24-
url: HttpUrl, cookies: Optional[dict] = None
25-
) -> Optional[bytes]:
26-
"""Asynchronously fetches the bytes data of an image from the given URL.
27-
28-
Args:
29-
url (str): The URL of the image.
30-
cookies (dict, optional): Cookies to be used for downloading the image.
31-
32-
Returns:
33-
Optional[bytes]: The bytes data of the image, or None if fetching fails.
34-
"""
23+
async def fetch_bytes(url: HttpUrl) -> Optional[bytes]:
3524
try:
36-
async with httpx.AsyncClient(
37-
follow_redirects=True, cookies=cookies
38-
) as client:
39-
response = await client.get(url)
25+
async with httpx.AsyncClient(follow_redirects=True) as client:
26+
response = await client.get(str(url))
4027
response.raise_for_status()
4128
return response.content
4229
except Exception as e:
4330
print(f"Failed to download {url}: {str(e)}")
44-
pass
31+
return None
4532

33+
@classmethod
4634
async def save(
47-
self,
48-
save_path: str = "cached",
49-
filename: Optional[str] = None,
50-
cookies: Optional[dict] = None,
35+
cls, images: List["GeminiImage"], save_path: str = "cached"
5136
) -> Optional[Path]:
52-
"""
53-
Downloads the image from the URL and saves it to the specified save_path.
54-
55-
Args:
56-
save_path: The directory to save the image (default: "cached").
57-
filename: The filename for the saved image (optional, uses title if not provided).
58-
cookies: Optional cookies dictionary for the download request.
59-
60-
Returns:
61-
The save_path to the saved image file or None if download fails.
62-
"""
63-
image_data = await self.fetch_bytes(self.url, cookies)
64-
GeminiImage.validate_images(image_data)
65-
now = datetime.datetime.now().strftime("%y%m%d%H%M%S")
66-
67-
if not filename:
68-
filename = f"{self.title.replace(' ', '_').replace('[Image]', '').replace('/', '_').replace(':', '_')}_{random.randint(10,99)}_{now}.jpg"
37+
cls.validate_images(images)
38+
image_data = await cls.fetch_images_dict(images)
39+
await cls.save_images(image_data, save_path)
6940

70-
save_path = Path(save_path) / filename
71-
save_path.parent.mkdir(parents=True, exist_ok=True)
41+
@classmethod
42+
async def fetch_images_dict(cls, images: List["GeminiImage"]) -> Dict[str, bytes]:
43+
cls.validate_images(images)
44+
tasks = [cls.fetch_bytes(image.url) for image in images]
45+
results = await asyncio.gather(*tasks)
46+
return {image.title: result for image, result in zip(images, results) if result}
7247

73-
try:
74-
save_path.write_bytes(image_data)
75-
print(f"Saved {self.title} to {save_path}")
76-
return save_path
77-
except Exception as e:
78-
print(f"Failed to save image {self.title}: {str(e)}")
79-
return None
48+
@staticmethod
49+
async def save_images(image_data: Dict[str, bytes], save_path: str = "cached"):
50+
os.makedirs(save_path, exist_ok=True)
51+
for title, data in image_data.items():
52+
now = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")
53+
filename = f"{title.replace(' ', '_')}_{now}.jpg"
54+
filepath = Path(save_path) / filename
55+
try:
56+
with open(filepath, "wb") as f:
57+
f.write(data)
58+
print(f"Saved {title} to {filepath}")
59+
except Exception as e:
60+
print(f"Error saving {title}: {str(e)}")
8061

8162
@staticmethod
82-
async def fetch_images_dict(
83-
images: List["GeminiImage"], cookies: Optional[dict] = None
84-
) -> Dict[str, bytes]:
85-
"""Asynchronously fetches bytes data for a list of images.
63+
def fetch_bytes_sync(url: HttpUrl) -> Optional[bytes]:
64+
"""Synchronously fetches the bytes data of an image from the given URL.
8665
8766
Args:
88-
images (List[GeminiImage]): A list of GeminiImage objects.
89-
cookies (dict, optional): Cookies to be used for downloading the images.
67+
url (str): The URL of the image.
9068
9169
Returns:
92-
Dict[str, bytes]: A dictionary mapping image titles to bytes data.
70+
Optional[bytes]: The bytes data of the image, or None if fetching fails.
9371
"""
94-
GeminiImage.validate_images(images)
9572
try:
96-
tasks = [image.fetch_bytes(image.url, cookies) for image in images]
97-
results = await asyncio.gather(*tasks, return_exceptions=True)
98-
return {
99-
images[i].title: result
100-
for i, result in enumerate(results)
101-
if isinstance(result, bytes)
102-
}
73+
with httpx.Client(follow_redirects=True) as client:
74+
response = client.get(str(url))
75+
response.raise_for_status()
76+
return response.content
10377
except Exception as e:
104-
print(f"Error fetching images: {str(e)}")
105-
return {}
106-
107-
@staticmethod
108-
async def save_images(
109-
image_data: Dict[str, bytes], save_path: str = "cached", unique: bool = True
110-
):
111-
"""Asynchronously saves images specified by their bytes data.
112-
113-
Args:
114-
image_data (Dict[str, bytes]): A dictionary mapping image titles to bytes data.
115-
path (str, optional): The directory where the images will be saved. Defaults to "images".
116-
"""
117-
os.makedirs(save_path, exist_ok=True)
118-
if unique:
119-
titles = set(image_data.keys())
120-
image_data = {
121-
title: data for title, data in image_data.items() if title in titles
122-
}
123-
for title, data in image_data.items():
124-
now = datetime.datetime.now().strftime("%y%m%d%H%M%S%f")
125-
filename = f"{title.replace(' ', '_').replace('[Image]', '').replace('/', '_').replace(':', '_')}_{random.randint(10,99)}_{now}.jpg"
126-
filepath = Path(save_path) / filename
127-
try:
128-
filepath.write_bytes(data)
129-
print(f"Saved {title} to {save_path}")
130-
except:
131-
pass
78+
print(f"Failed to download {url}: {str(e)}")
79+
return None
13280

13381
@staticmethod
13482
def fetch_bytes_sync(

0 commit comments

Comments
 (0)