Skip to content

Commit c803bfa

Browse files
committed
Mypy fixes
1 parent 47d7308 commit c803bfa

File tree

2 files changed

+27
-9
lines changed

2 files changed

+27
-9
lines changed

app/backend/prepdocslib/embeddings.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -255,15 +255,10 @@ async def create_embedding_for_image(self, image_bytes: bytes) -> list[float]:
255255
raise ValueError("Failed to get image embedding after multiple retries.")
256256

257257
async def create_embedding_for_text(self, q: str):
258-
if not self.endpoint:
259-
raise ValueError("Azure AI Vision endpoint must be set to compute image embedding.")
260258
endpoint = urljoin(self.endpoint, "computervision/retrieval:vectorizeText")
261259
headers = {"Content-Type": "application/json"}
262260
params = {"api-version": "2024-02-01", "model-version": "2023-04-15"}
263261
data = {"text": q}
264-
265-
if not self.token_provider:
266-
raise ValueError("Azure AI Vision token provider must be set to compute image embedding.")
267262
headers["Authorization"] = "Bearer " + await self.token_provider()
268263

269264
async with aiohttp.ClientSession() as session:

app/backend/prepdocslib/listfilestrategy.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,33 @@ def __init__(self, content: IO, acls: Optional[dict[str, list]] = None, url: Opt
2828
self.acls = acls or {}
2929
self.url = url
3030

31-
def filename(self):
32-
if not self.content.name or self.content.name == "file":
33-
return os.path.basename(self.content.filename)
34-
return os.path.basename(self.content.name)
31+
def filename(self) -> str:
32+
"""
33+
Get the filename from the content object.
34+
35+
Different file-like objects store the filename in different attributes:
36+
- File objects from open() have a .name attribute
37+
- HTTP uploaded files (werkzeug.datastructures.FileStorage) have a .filename attribute
38+
39+
Returns:
40+
str: The basename of the file
41+
"""
42+
content_name = None
43+
44+
# Try to get filename attribute (common for HTTP uploaded files)
45+
if hasattr(self.content, "filename"):
46+
content_name = getattr(self.content, "filename")
47+
if content_name:
48+
return os.path.basename(content_name)
49+
50+
# Try to get name attribute (common for file objects from open())
51+
if hasattr(self.content, "name"):
52+
content_name = getattr(self.content, "name")
53+
if content_name and content_name != "file":
54+
return os.path.basename(content_name)
55+
56+
# If we couldn't determine a name, return a default
57+
return "unknown"
3558

3659
def file_extension(self):
3760
return os.path.splitext(self.filename())[1]

0 commit comments

Comments
 (0)