Skip to content

Commit 9a38aa4

Browse files
committed
First version of the Chatbot component
1 parent c88572b commit 9a38aa4

File tree

2 files changed

+62
-27
lines changed

2 files changed

+62
-27
lines changed

vuegen/report.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pandas as pd
88
import logging
99
import requests
10+
import json
1011
import matplotlib.pyplot as plt
1112
from pyvis.network import Network
1213

@@ -21,6 +22,7 @@ class ComponentType(StrEnum):
2122
DATAFRAME = auto()
2223
MARKDOWN = auto()
2324
APICALL = auto()
25+
CHATBOT = auto()
2426

2527
class PlotType(StrEnum):
2628
INTERACTIVE = auto()
@@ -313,7 +315,6 @@ def _add_size_attribute(self, G: nx.Graph) -> nx.Graph:
313315
size = min_size + (max_size - min_size) * ((degree - min_degree) / (max_degree - min_degree))
314316

315317
G.nodes[node]['size'] = size
316-
317318
return G
318319

319320
class DataFrame(Component):
@@ -392,14 +393,14 @@ def make_api_request(self, method: str, request_body: Optional[dict] = None) ->
392393
self.logger.error(f"API request failed: {e}")
393394
return None
394395

395-
class RAG(APICall):
396+
class ChatBot(APICall):
396397
"""
397-
A specialized component for interacting with Retrieval-Augmented Generation APIs.
398+
A specialized component for creating a ChatBot.
398399
399400
Attributes
400401
----------
401402
model : str
402-
The language model to use for retrieval.
403+
The language model to use.
403404
"""
404405
def __init__(self, id: int, name: str, api_url: str, model: str, title: str = None,
405406
caption: str = None, logger: Optional[logging.Logger] = None,
@@ -408,66 +409,72 @@ def __init__(self, id: int, name: str, api_url: str, model: str, title: str = No
408409
logger=logger, headers=headers, params=params)
409410
self.model = model
410411

411-
def get_retrieved_documents(self, prompt: str) -> Optional[list]:
412+
def get_chatbot_answer(self, prompt: str) -> dict:
412413
"""
413414
Sends a RAG query and retrieves the resulting documents.
414415
415416
Parameters
416417
----------
417418
prompt : str
418-
The prompt for retrieval.
419+
The prompt for asking the chatbot.
419420
420421
Returns
421422
-------
422-
documents : Optional[list]
423-
A list of retrieved documents, or None if the request fails.
423+
parsed_response : dict
424+
The chabtbot answer.
424425
"""
425426
request_body = self._generate_query(prompt)
426427
response = self.make_api_request(method="POST", request_body=request_body)
427428
if response:
428429
self.logger.info(f"Request successful")
429430
else:
430431
self.logger.warning("Nothing retreived.")
431-
return response
432+
parsed_response = self._parse_api_response(response)
433+
return parsed_response
432434

433-
def _generate_query(self, prompt: str) -> dict:
435+
def _generate_query(self, messages: str) -> dict:
434436
"""
435-
Constructs the payload for a RAG query.
437+
Constructs the request body for a question to the chatbot.
436438
437439
Parameters
438440
----------
439-
prompt : str
440-
The prompt for retrieval.
441+
messages : str
442+
The messages for retrieval.
441443
442444
Returns
443445
-------
444-
payload : dict
445-
The payload for the RAG API request.
446+
request_body : dict
447+
The request body for the question to the chatbot.
446448
"""
447-
self.logger.info(f"Generating request body for prompt: {prompt}")
449+
self.logger.info(f"Generating request body for message: {messages}")
448450
return {
449451
"model": self.model,
450-
"prompt": prompt,
451-
"stream": False
452+
"messages": messages,
453+
"stream": True
452454
}
453455

454-
def _parse_api_response(self, response: Optional[dict], key: Optional[str] = None) -> Optional[any]:
456+
def _parse_api_response(self, response: dict) -> dict:
455457
"""
456458
Extracts and processes data from the API response.
457459
458460
Parameters
459461
----------
460-
response : Optional[dict]
462+
response : dict
461463
The response from the API.
462-
key : Optional[str], optional
463-
A specific key to retrieve from the response (default is None).
464464
465465
Returns
466466
-------
467-
result : Optional[any]
468-
The extracted data from the response, or None if the key is not found.
469-
"""
470-
pass
467+
output : dict
468+
The extracted data from the response.
469+
"""
470+
output = ""
471+
for line in response.iter_lines():
472+
body = json.loads(line)
473+
if "error" in body:
474+
raise Exception(body["error"])
475+
if body.get("done", False):
476+
return {"role": "assistant", "content": output}
477+
output += body.get("message", {}).get("content", "")
471478

472479
@dataclass
473480
class Subsection:

vuegen/streamlit_reportview.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,8 @@ def _generate_subsection(self, subsection) -> tuple[List[str], List[str]]:
270270
subsection_content.extend(self._generate_markdown_content(component))
271271
elif component.component_type == r.ComponentType.APICALL:
272272
subsection_content.extend(self._generate_apicall_content(component))
273+
elif component.component_type == r.ComponentType.CHATBOT:
274+
subsection_content.extend(self._generate_chatbot_content(component))
273275
else:
274276
self.report.logger.warning(f"Unsupported component type '{component.component_type}' in subsection: {subsection.name}")
275277

@@ -448,7 +450,7 @@ def _generate_apicall_content(self, apicall) -> List[str]:
448450
try:
449451
apicall_content = []
450452
apicall_content.append(self._format_text(text=apicall.title, type='header', level=4, color='#2b8cbe'))
451-
apicall_response = apicall.make_api_request()
453+
apicall_response = apicall.make_api_request(method='GET')
452454
apicall_content.append(f"""st.write({apicall_response})\n""")
453455
except Exception as e:
454456
self.report.logger.error(f"Error generating content for APICall: {apicall.title}. Error: {str(e)}")
@@ -457,6 +459,32 @@ def _generate_apicall_content(self, apicall) -> List[str]:
457459
self.report.logger.info(f"Successfully generated content for APICall: '{apicall.title}'")
458460
return apicall_content
459461

462+
def _generate_chatbot_content(self, chatbot) -> List[str]:
463+
"""
464+
Generate content for a Markdown component.
465+
466+
Parameters
467+
----------
468+
chatbot : ChatBot
469+
The chatbot component to generate content for.
470+
471+
Returns
472+
-------
473+
list : List[str]
474+
The list of content lines for the chatbot.
475+
"""
476+
try:
477+
apicall_content = []
478+
apicall_content.append(self._format_text(text=chatbot.title, type='header', level=4, color='#2b8cbe'))
479+
apicall_response = chatbot.get_chatbot_answer()
480+
apicall_content.append(f"""st.write({apicall_response})\n""")
481+
except Exception as e:
482+
self.report.logger.error(f"Error generating content for APICall: {chatbot.title}. Error: {str(e)}")
483+
raise
484+
485+
self.report.logger.info(f"Successfully generated content for APICall: '{chatbot.title}'")
486+
return apicall_content
487+
460488
def _generate_component_imports(self, component: r.Component) -> List[str]:
461489
"""
462490
Generate necessary imports for a component of the report.

0 commit comments

Comments
 (0)