16 changes: 16 additions & 0 deletions ai/gen-ai-agents/travel_agent/README.md
@@ -0,0 +1,16 @@
# Travel Agent
This repository contains the code for a demo of a Travel Agent.
The AI agent enables a customer to get information about available destinations and to organize a trip, including booking flights and hotels.

The agent has been developed using OCI Generative AI and LangGraph.

## List of packages
* oci
* langchain-community
* langgraph
* streamlit
* fastapi
* black
* uvicorn
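
The dependencies can be installed with pip, for example:

```
pip install oci langchain-community langgraph streamlit fastapi black uvicorn
```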


63 changes: 63 additions & 0 deletions ai/gen-ai-agents/travel_agent/base_node.py
@@ -0,0 +1,63 @@
"""
Base Node class for LangGraph nodes.

This module defines a base class `BaseNode` for all LangGraph nodes,
providing a standard logging interface via `log_info` and `log_error` methods.
Each subclass should implement the `invoke(input, config=None)` method.
"""

import logging
from langchain_core.runnables import Runnable


class BaseNode(Runnable):
"""
Abstract base class for LangGraph nodes.

All node classes in the graph should inherit from this base class.
It provides convenient logging utilities and stores a unique node name
for identification in logs and debugging.

Attributes:
name (str): Identifier for the node, used in logging.
logger (logging.Logger): Configured logger instance for the node.
"""

def __init__(self, name: str):
"""
Initialize the base node with a logger.

Args:
name (str): Unique name of the node for logging purposes.
"""
self.name = name
self.logger = logging.getLogger(name)
self.logger.setLevel(logging.INFO)

# Attach a default console handler if no handlers are present
if not self.logger.handlers:
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
formatter = logging.Formatter(
"[%(asctime)s] %(levelname)s in %(name)s: %(message)s"
)
handler.setFormatter(formatter)
self.logger.addHandler(handler)

def log_info(self, message: str):
"""
Log an informational message.

Args:
message (str): The message to log.
"""
self.logger.info("[%s] %s", self.name, message)

def log_error(self, message: str):
"""
Log an error message.

Args:
message (str): The error message to log.
"""
self.logger.error("[%s] %s", self.name, message)
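
For illustration, a minimal hypothetical subclass might look like the sketch below; the node name `echo` is an assumption, while the state keys `user_input` and `final_plan` mirror the keys used by `AnswerInfoNode` later in this PR:

```python
from base_node import BaseNode


class EchoNode(BaseNode):
    """Hypothetical example node: copies the user's input into the plan."""

    def __init__(self):
        # The name passed here becomes the logger name and the log prefix.
        super().__init__("echo")

    def invoke(self, state, config=None, **kwargs):
        # Subclasses implement invoke(state, config=None) and return the
        # (possibly updated) state, as LangGraph expects.
        self.log_info("Echoing user input.")
        state["final_plan"] = state["user_input"]
        return state
```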
36 changes: 36 additions & 0 deletions ai/gen-ai-agents/travel_agent/config.py
@@ -0,0 +1,36 @@
"""
General configuration options
"""

#
# application configs
#
DEBUG = False

# this is the list of the mandatory fields in user input
# if any of these fields is missing, the agent will ask for clarification
REQUIRED_FIELDS = [
"place_of_departure",
"destination",
"start_date",
"end_date",
"num_persons",
"transport_type",
]

# OCI GenAI services configuration
REGION = "eu-frankfurt-1"
SERVICE_ENDPOINT = f"https://inference.generativeai.{REGION}.oci.oraclecloud.com"

# the agent appears to work with both of the models below
MODEL_ID = "meta.llama-3.3-70b-instruct"
# MODEL_ID = "cohere.command-a-03-2025"

MAX_TOKENS = 2048

# Mock API configuration
HOTEL_API_URL = "http://localhost:8000/search/hotels"
TRANSPORT_API_URL = "http://localhost:8000/search/transport"

# Hotel Map
MAP_STYLE = "https://basemaps.cartocdn.com/gl/positron-gl-style/style.json"
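
As a sketch of how `REQUIRED_FIELDS` might be consumed (the helper name and the dict-shaped trip request are assumptions for illustration):

```python
from config import REQUIRED_FIELDS


def missing_fields(trip_request: dict) -> list[str]:
    """Return the mandatory fields that are absent or empty in the user input."""
    return [field for field in REQUIRED_FIELDS if not trip_request.get(field)]


# Example: the agent would ask a clarification question for these three fields.
request = {"destination": "Valencia", "num_persons": 2, "transport_type": "train"}
print(missing_fields(request))  # ['place_of_departure', 'start_date', 'end_date']
```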
111 changes: 111 additions & 0 deletions ai/gen-ai-agents/travel_agent/mock_api.py
@@ -0,0 +1,111 @@
"""
mock_api.py

A simplified mock FastAPI server with two endpoints:
- /search/transport
- /search/hotels
"""

from fastapi import FastAPI, Query
from fastapi.responses import JSONResponse

app = FastAPI()


@app.get("/search/transport")
def search_transport(
destination: str = Query(...),
start_date: str = Query(...),
transport_type: str = Query(...),
):
"""
Mock endpoint to simulate transport search.
Args:
destination (str): Destination city.
start_date (str): Start date of the trip in 'YYYY-MM-DD' format.
transport_type (str): Type of transport (e.g., "airplane", "train").
Returns:
JSONResponse: Mocked transport options.
"""
return JSONResponse(
content={
"options": [
{
"provider": (
"TrainItalia" if transport_type == "train" else "Ryanair"
),
"price": 45.50,
"departure": f"{start_date}T09:00",
"arrival": f"{start_date}T13:00",
"type": transport_type,
}
]
}
)


@app.get("/search/hotels")
def search_hotels(destination: str = Query(...), stars: int = Query(3)):
"""
Mock endpoint to simulate hotel search.
Args:
destination (str): Destination city.
stars (int): Number of stars for hotel preference.
Returns:
JSONResponse: Mocked hotel options.
"""
hotels_by_city = {
"valencia": {
"name": "Hotel Vincci Lys",
"price": 135.0,
"stars": stars,
"location": "Central district",
"amenities": ["WiFi", "Breakfast"],
"latitude": 39.4702,
"longitude": -0.3750,
},
"barcelona": {
"name": "Hotel Jazz",
"price": 160.0,
"stars": stars,
"location": "Eixample",
"amenities": ["WiFi", "Rooftop pool"],
"latitude": 41.3849,
"longitude": 2.1675,
},
"madrid": {
"name": "Only YOU Hotel Atocha",
"price": 170.0,
"stars": stars,
"location": "Retiro",
"amenities": ["WiFi", "Gym", "Restaurant"],
"latitude": 40.4093,
"longitude": -3.6828,
},
"florence": {
"name": "Hotel L'Orologio Firenze",
"price": 185.0,
"stars": stars,
"location": "Santa Maria Novella",
"amenities": ["WiFi", "Spa", "Bar"],
"latitude": 43.7760,
"longitude": 11.2486,
},
"amsterdam": {
"name": "INK Hotel Amsterdam",
"price": 190.0,
"stars": stars,
"location": "City Center",
"amenities": ["WiFi", "Breakfast", "Bar"],
"latitude": 52.3745,
"longitude": 4.8901,
},
}

hotel_key = destination.strip().lower()
hotel = hotels_by_city.get(hotel_key)

if not hotel:
return JSONResponse(content={"hotels": []}, status_code=404)

return JSONResponse(content={"hotels": [hotel]})
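
A quick way to exercise the mock server, assuming it is started with `uvicorn mock_api:app --port 8000` (the `requests` package used below is an assumption; it is not in the README package list):

```python
import requests

# Query the mocked hotel search endpoint on the local server.
resp = requests.get(
    "http://localhost:8000/search/hotels",
    params={"destination": "Valencia", "stars": 4},
)
print(resp.status_code)  # 200
print(resp.json())       # {'hotels': [{'name': 'Hotel Vincci Lys', ...}]}
```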
29 changes: 29 additions & 0 deletions ai/gen-ai-agents/travel_agent/model_factory.py
@@ -0,0 +1,29 @@
"""
Factory for Chat models
"""

from langchain_community.chat_models import ChatOCIGenAI

from config import MODEL_ID, SERVICE_ENDPOINT
from config_private import COMPARTMENT_OCID


def get_chat_model(
    model_id: str = MODEL_ID,
    service_endpoint: str = SERVICE_ENDPOINT,
    temperature: float = 0,
    max_tokens: int = 2048,
) -> ChatOCIGenAI:
    """
    Factory function to create and return a ChatOCIGenAI model instance.

    Args:
        model_id (str): Identifier of the OCI GenAI model to use.
        service_endpoint (str): OCI GenAI inference endpoint URL.
        temperature (float): Sampling temperature; 0 for deterministic output.
        max_tokens (int): Maximum number of tokens to generate.

    Returns:
        ChatOCIGenAI: Configured chat model instance.
    """
# Create and return the chat model
return ChatOCIGenAI(
model_id=model_id,
service_endpoint=service_endpoint,
model_kwargs={"temperature": temperature, "max_tokens": max_tokens},
compartment_id=COMPARTMENT_OCID,
)
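
A minimal usage sketch, assuming valid OCI credentials and a `config_private.py` that defines `COMPARTMENT_OCID`:

```python
from model_factory import get_chat_model

# Create a model with slightly higher temperature and a smaller output budget.
llm = get_chat_model(temperature=0.2, max_tokens=512)

# LangChain chat models accept a plain string; the reply is a message object.
reply = llm.invoke("Suggest three destinations for a weekend trip.")
print(reply.content)
```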
76 changes: 76 additions & 0 deletions ai/gen-ai-agents/travel_agent/nodes/answer_info_node.py
@@ -0,0 +1,76 @@
# answer_info_node.py
# -*- coding: utf-8 -*-
"""
AnswerInfoNode

This module defines the AnswerInfoNode class, which is responsible for handling
general travel information queries within the LangGraph-based travel assistant.

When a user request is classified as an "info" intent by the router node,
this node generates a markdown-formatted response using a language model.

Author: L. Saetta
Date: 20/05/2025

"""
from langchain_core.runnables import Runnable
from langchain_core.output_parsers import StrOutputParser
from base_node import BaseNode
from model_factory import get_chat_model
from prompt_template import answer_prompt
from config import MODEL_ID, SERVICE_ENDPOINT, MAX_TOKENS, DEBUG


class AnswerInfoNode(BaseNode):
"""
Node in the LangGraph workflow responsible for handling general travel information queries.

This node is used when the user's intent is classified as an information request
(rather than a booking).
It uses a language model to generate a helpful, markdown-formatted response
based on the user's input.

Attributes:
prompt (PromptTemplate): The prompt template for generating the informational response.
llm (Runnable): The configured language model used for generation.
chain (Runnable): Composed chain of prompt → model → output parser.
"""

def __init__(self):
"""
Initialize the AnswerInfoNode with a pre-defined prompt and LLM configuration.

The chain is constructed from:
- `answer_prompt` (PromptTemplate)
- A chat model initialized via `get_chat_model`
- A `StrOutputParser` for plain string output
"""
super().__init__("answer_info")

self.prompt = answer_prompt
self.llm = get_chat_model(
model_id=MODEL_ID,
service_endpoint=SERVICE_ENDPOINT,
temperature=0.5,
max_tokens=MAX_TOKENS,
)
self.chain: Runnable = self.prompt | self.llm | StrOutputParser()

def invoke(self, state, config=None, **kwargs):
"""
Generate a general travel information response from the user's question.

Args:
state (dict): The current LangGraph state, which must include a 'user_input' key.
config (optional): Reserved for compatibility; not used.

Returns:
dict: Updated state with the 'final_plan' field set to the LLM-generated response.
"""
response = self.chain.invoke({"user_input": state["user_input"]}).strip()

if DEBUG:
self.log_info("Generated informational response.")

state["final_plan"] = response
return state
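
A small sketch of invoking the node directly, assuming the OCI GenAI configuration above is in place:

```python
from nodes.answer_info_node import AnswerInfoNode

node = AnswerInfoNode()
state = node.invoke({"user_input": "What can I visit in Valencia in two days?"})

# The markdown-formatted answer is stored under 'final_plan'.
print(state["final_plan"])
```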