|
1 | 1 | from abc import ABC
|
2 | 2 | from collections.abc import AsyncGenerator, Awaitable
|
3 |
| -from dataclasses import dataclass |
| 3 | +from dataclasses import dataclass, field |
4 | 4 | from enum import Enum
|
5 | 5 | from typing import Any, Callable, Optional, TypedDict, Union, cast
|
6 | 6 | from urllib.parse import urljoin
|
@@ -116,7 +116,7 @@ class DataPoints:
|
116 | 116 | @dataclass
|
117 | 117 | class ExtraInfo:
|
118 | 118 | data_points: DataPoints
|
119 |
| - thoughts: Optional[list[ThoughtStep]] = None |
| 119 | + thoughts: list[ThoughtStep] = field(default_factory=list) |
120 | 120 | followup_questions: Optional[list[Any]] = None
|
121 | 121 |
|
122 | 122 |
|
@@ -395,6 +395,8 @@ def nonewlines(s: str) -> str:
|
395 | 395 | text_sources.append(f"{citation}: {nonewlines(doc.content or '')}")
|
396 | 396 |
|
397 | 397 | if use_image_sources and hasattr(doc, "images") and doc.images:
|
| 398 | + if self.images_blob_container_client is None: |
| 399 | + raise ValueError("The images blob container client must be set to use image sources.") |
398 | 400 | for img in doc.images:
|
399 | 401 | # Skip if we've already processed this URL
|
400 | 402 | if img["url"] in seen_urls:
|
@@ -440,11 +442,15 @@ class ExtraArgs(TypedDict, total=False):
|
440 | 442 | return VectorizedQuery(vector=query_vector, k_nearest_neighbors=50, fields=self.embedding_field)
|
441 | 443 |
|
442 | 444 | async def compute_image_embedding(self, q: str):
|
| 445 | + if not self.vision_endpoint: |
| 446 | + raise ValueError("Azure AI Vision endpoint must be set to compute image embedding.") |
443 | 447 | endpoint = urljoin(self.vision_endpoint, "computervision/retrieval:vectorizeText")
|
444 | 448 | headers = {"Content-Type": "application/json"}
|
445 | 449 | params = {"api-version": "2024-02-01", "model-version": "2023-04-15"}
|
446 | 450 | data = {"text": q}
|
447 | 451 |
|
| 452 | + if not self.vision_token_provider: |
| 453 | + raise ValueError("Azure AI Vision token provider must be set to compute image embedding.") |
448 | 454 | headers["Authorization"] = "Bearer " + await self.vision_token_provider()
|
449 | 455 |
|
450 | 456 | async with aiohttp.ClientSession() as session:
|
|
0 commit comments