|
6 | 6 | import networkx as nx |
7 | 7 | import pandas as pd |
8 | 8 | import logging |
| 9 | +import requests |
9 | 10 | import matplotlib.pyplot as plt |
10 | 11 | from pyvis.network import Network |
11 | 12 |
|
@@ -343,6 +344,147 @@ def __init__(self, id: int, name: str, file_path: str, title: str=None, caption: |
343 | 344 | Initializes a DataFrame object. |
344 | 345 | """ |
345 | 346 | super().__init__(id, name, file_path, component_type=ComponentType.MARKDOWN, title=title, caption=caption, logger=logger) |
| 347 | + |
| 348 | +class APICall(Component): |
| 349 | + """ |
| 350 | + A component for interacting with APIs in a report. |
| 351 | +
|
| 352 | + Attributes |
| 353 | + ---------- |
| 354 | + api_url : str |
| 355 | + The URL of the API to interact with. |
| 356 | + headers : Optional[dict] |
| 357 | + Headers to include in the API request (default is None). |
| 358 | + params : Optional[dict] |
| 359 | + Query parameters to include in the API request (default is None). |
| 360 | + """ |
| 361 | + def __init__(self, id: int, name: str, file_path: str, api_url: str, |
| 362 | + title: str = None, caption: str = None, logger: Optional[logging.Logger] = None, |
| 363 | + headers: Optional[dict] = None, params: Optional[dict] = None): |
| 364 | + super().__init__(id, name, file_path, component_type=ComponentType.MARKDOWN, |
| 365 | + title=title, caption=caption, logger=logger) |
| 366 | + self.api_url = api_url |
| 367 | + self.headers = headers or {} |
| 368 | + self.params = params or {} |
| 369 | + |
| 370 | + def make_api_request(self, method: str = "GET", payload: Optional[dict] = None) -> Optional[dict]: |
| 371 | + """ |
| 372 | + Initiates an API request. |
| 373 | +
|
| 374 | + Parameters |
| 375 | + ---------- |
| 376 | + method : str, optional |
| 377 | + HTTP method to use for the request (default is "GET"). |
| 378 | + payload : Optional[dict], optional |
| 379 | + The request payload for POST or PUT methods (default is None). |
| 380 | +
|
| 381 | + Returns |
| 382 | + ------- |
| 383 | + response : Optional[dict] |
| 384 | + The JSON response from the API, or None if the request fails. |
| 385 | + """ |
| 386 | + try: |
| 387 | + self.logger.info(f"Making {method} request to API: {self.api_url}") |
| 388 | + response = requests.request(method, self.api_url, headers=self.headers, params=self.params, json=payload) |
| 389 | + response.raise_for_status() |
| 390 | + self.logger.info(f"Request successful with status code {response.status_code}.") |
| 391 | + return response.json() |
| 392 | + except requests.exceptions.RequestException as e: |
| 393 | + self.logger.error(f"API request failed: {e}") |
| 394 | + return None |
| 395 | + |
| 396 | + def parse_api_response(self, response: Optional[dict], key: Optional[str] = None) -> Optional[any]: |
| 397 | + """ |
| 398 | + Extracts and processes data from the API response. |
| 399 | +
|
| 400 | + Parameters |
| 401 | + ---------- |
| 402 | + response : Optional[dict] |
| 403 | + The response from the API. |
| 404 | + key : Optional[str], optional |
| 405 | + A specific key to retrieve from the response (default is None). |
| 406 | +
|
| 407 | + Returns |
| 408 | + ------- |
| 409 | + result : Optional[any] |
| 410 | + The extracted data from the response, or None if the key is not found. |
| 411 | + """ |
| 412 | + if not response: |
| 413 | + self.logger.error("No response to parse.") |
| 414 | + return None |
| 415 | + |
| 416 | + try: |
| 417 | + if key: |
| 418 | + self.logger.info(f"Parsing response for key: {key}") |
| 419 | + return response.get(key, None) |
| 420 | + self.logger.info("Returning full API response.") |
| 421 | + return response |
| 422 | + except Exception as e: |
| 423 | + self.logger.error(f"Failed to parse API response: {e}") |
| 424 | + return None |
| 425 | + |
| 426 | +class RAG(APICall): |
| 427 | + """ |
| 428 | + A specialized component for interacting with Retrieval-Augmented Generation APIs. |
| 429 | +
|
| 430 | + Attributes |
| 431 | + ---------- |
| 432 | + model_id : str |
| 433 | + The ID of the language model to use for retrieval. |
| 434 | + top_k : int |
| 435 | + The number of results to retrieve (default is 5). |
| 436 | + """ |
| 437 | + def __init__(self, id: int, name: str, file_path: str, api_url: str, model_id: str, |
| 438 | + top_k: int = 5, title: str = None, caption: str = None, logger: Optional[logging.Logger] = None, |
| 439 | + headers: Optional[dict] = None, params: Optional[dict] = None): |
| 440 | + super().__init__(id, name, file_path, api_url, title=title, caption=caption, |
| 441 | + logger=logger, headers=headers, params=params) |
| 442 | + self.model_id = model_id |
| 443 | + self.top_k = top_k |
| 444 | + |
| 445 | + def generate_query(self, user_input: str) -> dict: |
| 446 | + """ |
| 447 | + Constructs the payload for a RAG query. |
| 448 | +
|
| 449 | + Parameters |
| 450 | + ---------- |
| 451 | + user_input : str |
| 452 | + The input query for retrieval. |
| 453 | +
|
| 454 | + Returns |
| 455 | + ------- |
| 456 | + payload : dict |
| 457 | + The payload for the RAG API request. |
| 458 | + """ |
| 459 | + self.logger.info(f"Generating query payload for input: {user_input}") |
| 460 | + return { |
| 461 | + "model_id": self.model_id, |
| 462 | + "input": user_input, |
| 463 | + "top_k": self.top_k |
| 464 | + } |
| 465 | + |
| 466 | + def get_retrieved_documents(self, user_input: str) -> Optional[list]: |
| 467 | + """ |
| 468 | + Sends a RAG query and retrieves the resulting documents. |
| 469 | +
|
| 470 | + Parameters |
| 471 | + ---------- |
| 472 | + user_input : str |
| 473 | + The input query for retrieval. |
| 474 | +
|
| 475 | + Returns |
| 476 | + ------- |
| 477 | + documents : Optional[list] |
| 478 | + A list of retrieved documents, or None if the request fails. |
| 479 | + """ |
| 480 | + payload = self.generate_query(user_input) |
| 481 | + response = self.make_api_request(method="POST", payload=payload) |
| 482 | + documents = self.parse_api_response(response, key="retrieved_documents") |
| 483 | + if documents: |
| 484 | + self.logger.info(f"Retrieved {len(documents)} documents.") |
| 485 | + else: |
| 486 | + self.logger.warning("No documents retrieved.") |
| 487 | + return documents |
346 | 488 |
|
347 | 489 | @dataclass |
348 | 490 | class Subsection: |
|
0 commit comments