|
| 1 | +import uuid |
| 2 | + |
| 3 | +from azure.cosmos import CosmosClient |
| 4 | +from langchain_core.documents import Document |
| 5 | +from langchain_openai import AzureOpenAIEmbeddings |
| 6 | + |
| 7 | +from template_fastapi.models.restaurant import Restaurant |
| 8 | +from template_fastapi.settings.azure_cosmosdb import get_azure_cosmosdb_settings |
| 9 | +from template_fastapi.settings.azure_openai import get_azure_openai_settings |
| 10 | + |
| 11 | +# 設定の取得 |
| 12 | +azure_cosmosdb_settings = get_azure_cosmosdb_settings() |
| 13 | +azure_openai_settings = get_azure_openai_settings() |
| 14 | + |
| 15 | + |
| 16 | +class RestaurantRepository: |
| 17 | + """レストランデータを管理するリポジトリクラス""" |
| 18 | + |
| 19 | + def __init__(self): |
| 20 | + self.container = self._setup_cosmos_client() |
| 21 | + |
| 22 | + def _setup_cosmos_client(self): |
| 23 | + """Azure Cosmos DBに接続するクライアントを設定する""" |
| 24 | + client = CosmosClient.from_connection_string(azure_cosmosdb_settings.azure_cosmosdb_connection_string) |
| 25 | + db = client.get_database_client(azure_cosmosdb_settings.azure_cosmosdb_database_name) |
| 26 | + container = db.get_container_client(azure_cosmosdb_settings.azure_cosmosdb_container_name) |
| 27 | + return container |
| 28 | + |
| 29 | + def _get_embeddings(self, text: str) -> list[float]: |
| 30 | + """Azure OpenAIを使用してテキストのベクトル埋め込みを生成する""" |
| 31 | + embedding_model = AzureOpenAIEmbeddings( |
| 32 | + azure_endpoint=azure_openai_settings.azure_openai_endpoint, |
| 33 | + api_key=azure_openai_settings.azure_openai_api_key, |
| 34 | + azure_deployment=azure_openai_settings.azure_openai_model_embedding, |
| 35 | + api_version=azure_openai_settings.azure_openai_api_version, |
| 36 | + ) |
| 37 | + |
| 38 | + document = Document(page_content=text) |
| 39 | + embedding = embedding_model.embed_documents([document.page_content])[0] |
| 40 | + return embedding |
| 41 | + |
| 42 | + def _cosmos_item_to_restaurant(self, item: dict) -> Restaurant: |
| 43 | + """CosmosDBのアイテムをRestaurantモデルに変換する""" |
| 44 | + # 位置情報の取り出し |
| 45 | + latitude = None |
| 46 | + longitude = None |
| 47 | + if "location" in item and "coordinates" in item["location"]: |
| 48 | + longitude, latitude = item["location"]["coordinates"] |
| 49 | + |
| 50 | + return Restaurant( |
| 51 | + id=item.get("id"), |
| 52 | + name=item.get("name"), |
| 53 | + description=item.get("description"), |
| 54 | + price=float(item.get("price", 0)), |
| 55 | + latitude=latitude, |
| 56 | + longitude=longitude, |
| 57 | + tags=item.get("tags", []), |
| 58 | + ) |
| 59 | + |
| 60 | + def list_restaurants(self, limit: int = 10) -> list[Restaurant]: |
| 61 | + """レストラン一覧を取得する""" |
| 62 | + query = f"SELECT TOP {limit} * FROM c" |
| 63 | + items = list(self.container.query_items(query=query, enable_cross_partition_query=True)) |
| 64 | + return [self._cosmos_item_to_restaurant(item) for item in items] |
| 65 | + |
| 66 | + def get_restaurant(self, restaurant_id: str) -> Restaurant: |
| 67 | + """指定されたIDのレストラン情報を取得する""" |
| 68 | + item = self.container.read_item(item=restaurant_id, partition_key=restaurant_id) |
| 69 | + return self._cosmos_item_to_restaurant(item) |
| 70 | + |
| 71 | + def create_restaurant(self, restaurant: Restaurant) -> Restaurant: |
| 72 | + """新しいレストランを作成する""" |
| 73 | + # IDが指定されていない場合は自動生成 |
| 74 | + if not restaurant.id: |
| 75 | + restaurant.id = str(uuid.uuid4()) |
| 76 | + |
| 77 | + # ベクトル埋め込みの生成 |
| 78 | + description = restaurant.description or restaurant.name |
| 79 | + vector_embedding = self._get_embeddings(description) |
| 80 | + |
| 81 | + # 位置情報の構築 |
| 82 | + location = None |
| 83 | + if restaurant.latitude is not None and restaurant.longitude is not None: |
| 84 | + location = {"type": "Point", "coordinates": [restaurant.longitude, restaurant.latitude]} |
| 85 | + |
| 86 | + # CosmosDBに保存するアイテムの作成 |
| 87 | + item = { |
| 88 | + "id": restaurant.id, |
| 89 | + "name": restaurant.name, |
| 90 | + "description": restaurant.description, |
| 91 | + "price": restaurant.price, |
| 92 | + "tags": restaurant.tags, |
| 93 | + "vector": vector_embedding, |
| 94 | + } |
| 95 | + |
| 96 | + if location: |
| 97 | + item["location"] = location |
| 98 | + |
| 99 | + # CosmosDBに保存 |
| 100 | + created_item = self.container.create_item(body=item) |
| 101 | + return self._cosmos_item_to_restaurant(created_item) |
| 102 | + |
| 103 | + def update_restaurant(self, restaurant_id: str, restaurant: Restaurant) -> Restaurant: |
| 104 | + """既存のレストラン情報を更新する""" |
| 105 | + # 既存のアイテムを取得 |
| 106 | + existing_item = self.container.read_item(item=restaurant_id, partition_key=restaurant_id) |
| 107 | + |
| 108 | + # 説明文が変更された場合、新しいベクトル埋め込みを生成 |
| 109 | + description = restaurant.description or restaurant.name |
| 110 | + if description != (existing_item.get("description") or existing_item.get("name")): |
| 111 | + vector_embedding = self._get_embeddings(description) |
| 112 | + else: |
| 113 | + vector_embedding = existing_item.get("vector") |
| 114 | + |
| 115 | + # 位置情報の構築 |
| 116 | + location = None |
| 117 | + if restaurant.latitude is not None and restaurant.longitude is not None: |
| 118 | + location = {"type": "Point", "coordinates": [restaurant.longitude, restaurant.latitude]} |
| 119 | + |
| 120 | + # 更新するアイテムの作成 |
| 121 | + updated_item = { |
| 122 | + "id": restaurant_id, |
| 123 | + "name": restaurant.name, |
| 124 | + "description": restaurant.description, |
| 125 | + "price": restaurant.price, |
| 126 | + "tags": restaurant.tags, |
| 127 | + "vector": vector_embedding, |
| 128 | + } |
| 129 | + |
| 130 | + if location: |
| 131 | + updated_item["location"] = location |
| 132 | + |
| 133 | + # CosmosDBのアイテムを更新 |
| 134 | + result = self.container.replace_item(item=restaurant_id, body=updated_item) |
| 135 | + return self._cosmos_item_to_restaurant(result) |
| 136 | + |
| 137 | + def delete_restaurant(self, restaurant_id: str) -> None: |
| 138 | + """指定されたIDのレストランを削除する""" |
| 139 | + self.container.delete_item(item=restaurant_id, partition_key=restaurant_id) |
| 140 | + |
| 141 | + def search_restaurants(self, query: str, k: int = 3) -> list[Restaurant]: |
| 142 | + """キーワードによるレストランのベクトル検索を実行する""" |
| 143 | + # クエリテキストのベクトル埋め込みを生成 |
| 144 | + query_embedding = self._get_embeddings(query) |
| 145 | + |
| 146 | + # ベクトル検索クエリの実行 |
| 147 | + query_text = f""" |
| 148 | + SELECT TOP {k} * |
| 149 | + FROM c |
| 150 | + ORDER BY VectorDistance(c.vector, @queryVector) |
| 151 | + """ |
| 152 | + |
| 153 | + parameters = [{"name": "@queryVector", "value": query_embedding}] |
| 154 | + items = list( |
| 155 | + self.container.query_items(query=query_text, parameters=parameters, enable_cross_partition_query=True) |
| 156 | + ) |
| 157 | + return [self._cosmos_item_to_restaurant(item) for item in items] |
| 158 | + |
| 159 | + def find_nearby_restaurants( |
| 160 | + self, latitude: float, longitude: float, distance_km: float = 5.0, limit: int = 10 |
| 161 | + ) -> list[Restaurant]: |
| 162 | + """指定した位置の近くにあるレストランを検索する""" |
| 163 | + # 地理空間クエリの実行(メートル単位で距離を指定) |
| 164 | + distance_meters = distance_km * 1000 |
| 165 | + query_text = f""" |
| 166 | + SELECT TOP {limit} * |
| 167 | + FROM c |
| 168 | + WHERE ST_DISTANCE(c.location, {{ |
| 169 | + "type": "Point", |
| 170 | + "coordinates": [{longitude}, {latitude}] |
| 171 | + }}) < {distance_meters} |
| 172 | + """ |
| 173 | + |
| 174 | + items = list(self.container.query_items(query=query_text, enable_cross_partition_query=True)) |
| 175 | + return [self._cosmos_item_to_restaurant(item) for item in items] |
0 commit comments