-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathattached_file.py
More file actions
599 lines (492 loc) Β· 17.1 KB
/
attached_file.py
File metadata and controls
599 lines (492 loc) Β· 17.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
"""Module for managing attached files in conversations."""
from __future__ import annotations
import base64
import enum
import logging
import mimetypes
import re
from io import BufferedReader, BufferedWriter, BytesIO
from typing import Any
import httpx
from PIL import Image
from pydantic import (
BaseModel,
Field,
SerializationInfo,
SerializerFunctionWrapHandler,
ValidationInfo,
ValidatorFunctionWrapHandler,
field_serializer,
field_validator,
model_validator,
)
from upath import UPath
from upath.implementations.http import HTTPPath
from basilisk.decorators import measure_time
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Matches either an http(s) URL (running up to the first whitespace, '<', '>'
# or '"' character) or a "data:" URI (running up to the first whitespace).
# Matching is case-insensitive.
URL_PATTERN = re.compile(r'https?://[^\s<>"]+|data:\S+', re.IGNORECASE)
def get_image_dimensions(reader: BufferedReader) -> tuple[int, int]:
    """Return the pixel size of the image readable from ``reader``.

    Args:
        reader: A binary file-like object containing the image data.

    Returns:
        The ``(width, height)`` pair Pillow reports for the image.
    """
    opened_image = Image.open(reader)
    dimensions = opened_image.size
    return dimensions
def resize_image(
    src: BufferedReader,
    target: BufferedWriter,
    format: str,
    max_width: int = 0,
    max_height: int = 0,
    quality: int = 85,
) -> bool:
    """Downscale an image to fit the given bounds and write it to ``target``.

    The aspect ratio is preserved. A bound of 0 (or less) leaves that
    dimension unconstrained. An image that already fits within every
    active bound is never enlarged; in that case nothing is written.

    Args:
        src: Binary file-like object the source image is read from.
        target: Binary file-like object the resized image is written to.
        format: Pillow format name of the output (e.g., "JPEG", "PNG").
            Common file-extension spellings such as "jpg" or "tif" are
            mapped to the corresponding Pillow format name.
        max_width: Maximum width of the output. If 0, only `max_height` is used to calculate the ratio.
        max_height: Maximum height of the output. If 0, only `max_width` is used to calculate the ratio.
        quality: The quality of the compressed image (1-100).

    Returns:
        True if a resized image was written to ``target``, False if no
        resizing was needed (both bounds unset, or the image already fits).
    """
    if max_width <= 0 and max_height <= 0:
        log.debug("No resizing needed")
        return False

    image = Image.open(src)
    orig_width, orig_height = image.size

    # Only bounds that are actually set (> 0) take part in the scale factor.
    # Comparing the size against an unset bound of 0 would always fail and
    # let images smaller than the active bound be upscaled.
    ratios = []
    if max_width > 0:
        ratios.append(max_width / orig_width)
    if max_height > 0:
        ratios.append(max_height / orig_height)
    ratio = min(ratios)
    if ratio >= 1:
        log.debug("Image is already smaller than max dimensions")
        return False

    if image.mode in ("RGBA", "P"):
        image = image.convert("RGB")

    # int() truncates; clamp to 1 so extreme aspect ratios never produce a
    # zero-sized dimension, which Pillow rejects.
    new_width = max(1, int(orig_width * ratio))
    new_height = max(1, int(orig_height * ratio))
    resized_image = image.resize(
        (new_width, new_height), Image.Resampling.LANCZOS
    )

    # Pillow looks formats up by registered name ("JPEG", "TIFF"), not by
    # file extension, so translate the usual extension spellings.
    format_aliases = {"JPG": "JPEG", "JPE": "JPEG", "JFIF": "JPEG", "TIF": "TIFF"}
    pil_format = format.lstrip(".").upper()
    pil_format = format_aliases.get(pil_format, pil_format)

    resized_image.save(
        target, optimize=True, quality=quality, format=pil_format
    )
    return True
def parse_supported_attachment_formats(
    supported_attachment_formats: set[str],
) -> str:
    """Build a file-dialog wildcard string from a set of MIME types.

    Args:
        supported_attachment_formats: The supported attachment formats, given as MIME types.

    Returns:
        A ";"-separated list of "*<extension>" patterns covering every MIME
        type for which at least one extension is known. MIME types are
        processed in sorted order.
    """
    patterns: list[str] = []
    for mime in sorted(supported_attachment_formats):
        extensions = mimetypes.guess_all_extensions(mime)
        if not extensions:
            log.warning("No extensions found for MIME type %s", mime)
            continue
        log.debug("Adding wildcard for MIME type %s: %s", mime, extensions)
        patterns.append(";".join(f"*{ext}" for ext in extensions))
    return ";".join(patterns)
def get_mime_type(path: UPath | str) -> str | None:
    """Guess the MIME type of a file from its path.

    Args:
        path: The file path; a UPath is converted to its URI form before guessing.

    Returns:
        The guessed MIME type, or None if it cannot be determined.
    """
    target = path.as_uri() if isinstance(path, UPath) else path
    mime_type, _encoding = mimetypes.guess_type(target)
    return mime_type
@measure_time
def build_from_url(url: str) -> AttachmentFile:
    """Fetch a file from a given URL and create an AttachmentFile instance.

    The URL is downloaded (following redirects) and the response headers are
    used to fill in the size and MIME type. Responses whose MIME type starts
    with "image/" produce an ImageFile including the image dimensions.

    Args:
        url: The URL of the file to retrieve.

    Returns:
        An ImageFile for image responses, otherwise an AttachmentFile, both
        of type URL and pointing at ``url``.

    Raises:
        httpx.HTTPError: If there is an error during the HTTP request.
        NotImageError: If the response carries no usable Content-Type.

    Example:
        file = build_from_url("https://example.com/file.pdf")
        image = build_from_url("https://example.com/image.jpg")
    """
    response = httpx.get(url, follow_redirects=True)
    response.raise_for_status()

    content_length = response.headers.get("Content-Length")
    if content_length and content_length.isdigit():
        size = int(content_length)
    else:
        # Header missing or malformed: never pass the raw string on to the
        # integer ``size`` field; use the number of bytes actually received.
        size = len(response.content)

    content_type = response.headers.get("content-type")
    # Drop parameters such as "; charset=utf-8" and normalise case so that
    # only the bare MIME type is stored and compared.
    mime_type = (
        content_type.split(";", 1)[0].strip().lower() if content_type else ""
    )
    if not mime_type:
        raise NotImageError("No MIME type found")

    if mime_type.startswith("image/"):
        dimensions = get_image_dimensions(BytesIO(response.content))
        return ImageFile(
            location=url,
            type=AttachmentFileTypes.URL,
            size=size,
            mime_type=mime_type,
            dimensions=dimensions,
        )
    return AttachmentFile(
        location=url,
        type=AttachmentFileTypes.URL,
        size=size,
        mime_type=mime_type,
    )
class AttachmentFileTypes(enum.StrEnum):
"""Enumeration of file types based on their source location."""
# The file type is unknown.
UNKNOWN = enum.auto()
# The file is stored on the local filesystem.
LOCAL = enum.auto()
# The file is stored in memory (RAM).
MEMORY = enum.auto()
# The file is stored at a URL.
URL = enum.auto()
@classmethod
def _missing_(cls, value: object) -> AttachmentFileTypes:
"""Determine the enum value for a given input value.
This method is a custom implementation for handling enum value mapping when a non-standard value is provided. It maps specific string inputs to predefined ImageFileTypes.
The mapping is as follows:
- "http", "https", "data" -> AttachmentFileTypes.URL
- "zip" -> AttachmentFileTypes.LOCAL
- Any other value -> AttachmentFileTypes.UNKNOWN
Args:
value: The input value to be mapped to an ImageFileTypes enum.
Returns:
The corresponding AttachmentFileTypes enum value for the input value.
"""
if not isinstance(value, str):
return cls.UNKNOWN
value = value.lower()
if value in {"data", "http", "https"}:
return cls.URL
if value == "zip":
return cls.LOCAL
return cls.UNKNOWN
class NotImageError(ValueError):
    """Error signalling that a URL does not point to an image file."""
class AttachmentFile(BaseModel):
    """Represents an attached file in a conversation."""

    # Where the file lives: a local path, an in-memory path or a URL.
    location: UPath
    # Display name; derived from the location when not given.
    name: str | None = None
    # Optional free-form description.
    description: str | None = None
    # Size in bytes; None when it cannot be determined.
    size: int | None = None
    # MIME type; None when it cannot be determined.
    mime_type: str | None = None
    # Database identifier; excluded from serialization.
    db_id: int | None = Field(default=None, exclude=True)

    @field_serializer("location", mode="wrap")
    @classmethod
    def change_location(
        cls,
        value: UPath,
        wrap_handler: SerializerFunctionWrapHandler,
        info: SerializationInfo,
    ) -> UPath | str:
        """Serialize the location field with optional context-based mapping.

        HTTP locations are serialized as plain strings. For other locations,
        an "attachment_mapping" entry in the serialization context may supply
        a replacement value; without one the default handler is used.

        Args:
            value: The original location path to be serialized.
            wrap_handler: The default serialization handler.
            info: Serialization context information.

        Returns:
            The serialized location path, potentially remapped based on context.
        """
        if isinstance(value, HTTPPath):
            return str(value)
        mapping = (
            info.context.get("attachment_mapping") if info.context else None
        )
        if mapping and value in mapping:
            return mapping[value]
        # No mapping, or no entry for this location: default serialization.
        return wrap_handler(value)

    @field_validator("location", mode="wrap")
    @classmethod
    def validate_location(
        cls,
        value: Any,
        wrap_handler: ValidatorFunctionWrapHandler,
        info: ValidationInfo,
    ) -> str | UPath:
        """Validate and transform the location of a file.

        Strings matching URL_PATTERN are handed to the default handler.
        Other strings, and dict inputs, are resolved against the "root_path"
        entry of the validation context when one is present; otherwise they
        go through the default handler so the field never ends up holding an
        unvalidated str or dict.

        Args:
            value: The location value to validate (string, dict or UPath).
            wrap_handler: The default validation handler.
            info: Validation context containing additional information.

        Returns:
            A validated and potentially transformed location.

        Raises:
            ValueError: If the location is not a string, dict or UPath instance.
        """
        root_path = info.context.get("root_path") if info.context else None
        if isinstance(value, str):
            if URL_PATTERN.match(value):
                return wrap_handler(value)
            if root_path:
                return root_path / value
            # No root to resolve against: still validate, rather than
            # returning the raw string as the field value.
            return wrap_handler(value)
        if isinstance(value, dict):
            validated = wrap_handler(value)
            # Join only when a root path exists; a context without
            # "root_path" previously led to ``None / value``.
            if root_path:
                return root_path / validated
            return validated
        if not isinstance(value, UPath):
            raise ValueError("Invalid location")
        return value

    def __init__(self, /, **kwargs: Any) -> None:
        """Initialize an AttachmentFile instance with optional data.

        After validation, a missing name, MIME type or size is filled in via
        _get_name(), _get_mime_type() and _get_size() respectively.

        Args:
            kwargs: Keyword arguments for initializing the AttachmentFile instance. Can include optional attributes like name, size, and description.
        """
        super().__init__(**kwargs)
        # Work from the validated field values (not the raw kwargs) so that
        # unvalidated input is never written back onto the model.
        self.name = self.name or self._get_name()
        self.mime_type = self.mime_type or self._get_mime_type()
        self.size = self.size or self._get_size()

    __init__.__pydantic_base_init__ = True

    @model_validator(mode="after")
    def validate_location_exist(self) -> AttachmentFile:
        """Validate that a non-URL location exists.

        Returns:
            The validated instance.

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        if self.type == AttachmentFileTypes.URL:
            return self
        if not self.location.exists():
            raise FileNotFoundError(f"File {self.location} does not exist")
        return self

    @property
    def type(self) -> AttachmentFileTypes:
        """Determine the type of file based on its location protocol.

        Returns:
            LOCAL for an empty or "file" protocol, otherwise the enum value
            derived from the protocol of the file's location.
        """
        if self.location.protocol in ("", "file"):
            return AttachmentFileTypes.LOCAL
        return AttachmentFileTypes(self.location.protocol)

    def _get_name(self) -> str:
        """Get the name of the file.

        Returns:
            The name of the file, extracted from the file path.
        """
        return self.location.name

    def _get_size(self) -> int | None:
        """Get the size of the file.

        Returns:
            The size of the file in bytes, or None for URL files.
        """
        if self.type == AttachmentFileTypes.URL:
            return None
        return self.location.stat().st_size

    @property
    def display_size(self) -> str:
        """Get the human-readable size of the file.

        Returns:
            The size formatted in B, KB or MB (e.g., "1.23 MB"), or a
            translated "Unknown" placeholder when the size is not known.
        """
        size = self.size
        if size is None:
            # Translators: Placeholder for an unknown file size
            return _("Unknown")
        if size < 1024:
            return f"{size} B"
        if size < 1024 * 1024:
            return f"{size / 1024:.2f} KB"
        return f"{size / 1024 / 1024:.2f} MB"

    @property
    def send_location(self) -> UPath:
        """Get the location of the file to send.

        Returns:
            The instance's ``resize_location`` attribute when it exists and
            is set, otherwise the original location.
        """
        return getattr(self, "resize_location", None) or self.location

    def _get_mime_type(self) -> str | None:
        """Get the MIME type of the file.

        Returns:
            The MIME type guessed from the send location, or None for URL
            files or when the type cannot be determined.
        """
        if self.type == AttachmentFileTypes.URL:
            return None
        return get_mime_type(self.send_location)

    @property
    def display_location(self) -> str:
        """Get the display location of the file.

        Returns:
            The location as a string; long "data:" URIs are shortened to
            their first 50 and last 10 characters.
        """
        location = str(self.location)
        # 50 + len("...") + 10 == 63: shorter strings would only get longer
        # (with duplicated characters) if "truncated".
        if location.startswith("data:") and len(location) > 63:
            location = f"{location[:50]}...{location[-10:]}"
        return location

    @staticmethod
    def remove_location(location: UPath) -> None:
        """Remove a file at the specified location.

        Best effort: any failure is logged and not re-raised.

        Args:
            location: The location of the file to remove.
        """
        log.debug("Removing file at %s", location)
        try:
            fs = location.fs
            fs.rm(location.path)
        except Exception as e:
            log.error("Error deleting file at '%s': %s", location, e)

    def _read_file(self, mode: str) -> str | bytes:
        """Read the file content using the specified mode.

        Args:
            mode: The file opening mode ("r" for text, "rb" for bytes).

        Returns:
            The file content in the format specified by the mode.
        """
        open_kwargs: dict[str, Any] = {"mode": mode}
        # Use UTF-8 encoding for text modes to ensure consistent behavior across platforms
        if "b" not in mode:
            open_kwargs["encoding"] = "utf-8"
        with self.send_location.open(**open_kwargs) as file:
            return file.read()

    def read_as_plain_text(self) -> str:
        """Read the file as a plain text string.

        Returns:
            The contents of the file as a plain text string.
        """
        return self._read_file("r")

    def read_as_bytes(self) -> bytes:
        """Read the file as bytes.

        Returns:
            The contents of the file as bytes.
        """
        return self._read_file("rb")

    def encode_base64(self) -> str:
        """Encode the file as a base64 string.

        Returns:
            A base64-encoded string representing the file.
        """
        with self.send_location.open(mode="rb") as file:
            return base64.b64encode(file.read()).decode("utf-8")

    def __del__(self):
        """Remove the backing file of an in-memory attachment."""
        try:
            file_type = self.type
            location = self.location
        except AttributeError:
            # The instance never finished initialising (e.g. validation
            # failed), so there is nothing to clean up.
            return
        if file_type == AttachmentFileTypes.MEMORY:
            self.remove_location(location)

    def get_display_info(self) -> tuple[str, str, str]:
        """Get the name, size and location of the file.

        Returns:
            A tuple containing the name, size and location of the file.
        """
        return self.name, self.display_size, self.display_location

    @property
    def url(self) -> str:
        """Get the URL of the file.

        Returns:
            The location itself for URL files, otherwise a base64 "data:"
            URI built from the file content.
        """
        if self.type == AttachmentFileTypes.URL:
            return str(self.location)
        # An unknown MIME type must not leak into the URI as "None".
        mime_type = self.mime_type or "application/octet-stream"
        base64_data = self.encode_base64()
        return f"data:{mime_type};base64,{base64_data}"
class ImageFile(AttachmentFile):
    """Represents an image file in a conversation."""

    # (width, height) of the image; None when unknown.
    dimensions: tuple[int, int] | None = None
    # Location of the resized copy, if one was produced; excluded from
    # serialization.
    resize_location: UPath | None = Field(default=None, exclude=True)

    def __init__(self, /, **kwargs: Any) -> None:
        """Initialize an ImageFile instance with optional data.

        Runs the parent initialization, then determines the image dimensions
        via _get_dimensions() when none were provided.

        Args:
            kwargs: Keyword arguments for initializing the ImageFile instance. Can include optional attributes like name, size, and dimensions.
        """
        super().__init__(**kwargs)
        self.dimensions = self.dimensions or self._get_dimensions()

    __init__.__pydantic_base_init__ = True

    def _get_dimensions(self) -> tuple[int, int] | None:
        """Read the image dimensions from the file.

        Returns:
            The (width, height) of the image, or None for URL files.
        """
        if self.type == AttachmentFileTypes.URL:
            return None
        with self.location.open(mode="rb") as image_file:
            return get_image_dimensions(image_file)

    @property
    def display_dimensions(self) -> str:
        """Get the human-readable dimensions of the image.

        Returns:
            The dimensions of the image in a human-readable format (e.g., "1920 x 1080").
        """
        if self.dimensions is None:
            # Translators: Placeholder for unknown image dimensions
            return _("Unknown")
        return f"{self.dimensions[0]} x {self.dimensions[1]}"

    @measure_time
    def resize(
        self, conv_folder: UPath, max_width: int, max_height: int, quality: int
    ):
        """Resize the image to specified dimensions and save to a new location.

        Does nothing for URL images. Otherwise a resized copy is written to
        the "optimized_images" sub-folder of ``conv_folder`` and recorded in
        ``resize_location``; the original image remains unchanged. When no
        resized copy is produced, ``resize_location`` is set to None and the
        output file is removed again.

        Args:
            conv_folder: Folder where the resized image will be saved
            max_width: Maximum width for the resized image
            max_height: Maximum height for the resized image
            quality: Compression quality for the resized image (1-100)
        """
        if AttachmentFileTypes.URL == self.type:
            return
        log.debug("Resizing image")
        resize_folder = conv_folder.joinpath("optimized_images")
        resize_folder.mkdir(parents=True, exist_ok=True)
        resize_location = resize_folder / self.location.name

        # Pillow expects a registered format name ("JPEG"), not a file
        # extension ("jpg"); translate via Pillow's own extension table and
        # fall back to the bare suffix for unknown extensions.
        suffix = self.location.suffix.lower()
        image_format = Image.registered_extensions().get(suffix, suffix[1:])

        success = False
        try:
            with (
                self.location.open(mode="rb") as src_file,
                resize_location.open(mode="wb") as dst_file,
            ):
                success = resize_image(
                    src_file,
                    max_width=max_width,
                    max_height=max_height,
                    quality=quality,
                    target=dst_file,
                    format=image_format,
                )
        finally:
            if not success:
                # Opening the target for writing already created/truncated
                # it; do not leave an empty or partial file behind.
                self.remove_location(resize_location)
            self.resize_location = resize_location if success else None

    @measure_time
    def encode_base64(self) -> str:
        """Encode the image file as a base64 string.

        Logs a warning before encoding very large images.

        Returns:
            A base64-encoded string representing the image file.
        """
        # NOTE(review): the threshold is 1024**3 bytes (1 GiB); confirm that
        # this, rather than 1 MiB, is the intended "large image" limit.
        if self.size and self.size > 1024 * 1024 * 1024:
            log.warning(
                "Large image (%s) being encoded to base64", self.display_size
            )
        return super().encode_base64()

    @property
    def display_location(self):
        """Get the display location of the image file.

        Returns:
            The location as a string; long "data:image/" URIs are shortened
            to their first 50 and last 10 characters.
        """
        location = str(self.location)
        # 50 + len("...") + 10 == 63: shorter strings would only get longer
        # (with duplicated characters) if "truncated".
        if location.startswith("data:image/") and len(location) > 63:
            location = f"{location[:50]}...{location[-10:]}"
        return location

    def __del__(self):
        """Delete the resized copy, then run the parent cleanup."""
        try:
            is_url = self.type == AttachmentFileTypes.URL
            resized = self.resize_location
        except AttributeError:
            # The instance never finished initialising (e.g. validation
            # failed), so there is nothing to clean up.
            return
        if is_url:
            return
        if resized:
            self.remove_location(resized)
        super().__del__()