Commit 89d01c8

Merge pull request #64 from AET-DevOps25/37-genai-rag-learning-path

37 genai rag learning path. Closes #37

2 parents 3f70a9e + 5432350
File tree

7 files changed: +339 -21 lines changed

genai/src/main.py

Lines changed: 29 additions & 15 deletions

```diff
@@ -19,23 +19,25 @@
 from .services.embedding.schemas import EmbedRequest, EmbedResponse, QueryRequest, QueryResponse, DocumentResult
 from .services.embedding.weaviate_service import get_weaviate_client, ensure_schema_exists, DOCUMENT_CLASS_NAME
 from .services.llm import llm_service
-from .services.llm.schemas import GenerateRequest, GenerateResponse
-from .utils.error_schema import ErrorResponse
+from .services.llm.schemas import GenerateRequest, GenerateResponse
+from .services.rag.schemas import CourseGenerationRequest, Course
+from .services.rag import course_generator
+from .utils.error_schema import ErrorResponse
 from .utils.handle_httpx_exception import handle_httpx_exception

-
 # --- Configuration ---
 load_dotenv()
 logger = logging.getLogger("skillforge.genai")

 APP_PORT = int(os.getenv("GENAI_PORT", "8082"))
 APP_TITLE = os.getenv("GENAI_APP_NAME", "SkillForge GenAI Service")
 APP_VERSION = os.getenv("GENAI_APP_VERSION", "0.0.1")
-APP_DESCRIPTION = (
-    "SkillForge GenAI Service provides endpoints for web crawling, "
-    "chunking, embedding, semantic querying, and text generation using LLMs. "
-    "Ideal for integrating vector search and AI-driven workflows."
-)
+APP_DESCRIPTION = (
+    "SkillForge GenAI Service provides endpoints for web crawling, "
+    "chunking, embedding, semantic querying, and text generation using LLMs. "
+    "Ideal for integrating vector search and AI-driven workflows."
+)
+API_PREFIX = "/api/v1"
 TAGS_METADATA = [
     {"name": "System", "description": "Health checks and system status."},
     {"name": "Crawler", "description": "Crawl and clean website content."},
@@ -110,7 +112,7 @@ async def unhandled_exception_handler(request: Request, exc: Exception):

 # ---- System Endpoints --------
 # -------------------------------
-@app.get("/health", tags=["System"])
+@app.get(f"{API_PREFIX}/health", tags=["System"])
 async def health():
     """
     Deep health check. Verifies the application and its core dependencies (e.g., DB, vector store).
@@ -126,7 +128,7 @@ async def health():
             content={"status": "error", "message": "Dependency failure. See logs for details."}
         )

-@app.get("/ping", tags=["System"])
+@app.get(f"{API_PREFIX}/ping", tags=["System"])
 async def ping():
     """
     Lightweight liveness check. Confirms the API process is running, but does not check dependencies.
@@ -139,7 +141,7 @@ async def ping():
 # -------------------------------
 # ----- Crawler endpoints -----
 # -------------------------------
-@app.post("/crawl", response_model=CrawlResponse, responses={400: {"model": ErrorResponse}, 500: {"model": ErrorResponse}}, tags=["Crawler"])
+@app.post(f"{API_PREFIX}/crawl", response_model=CrawlResponse, responses={400: {"model": ErrorResponse}, 500: {"model": ErrorResponse}}, tags=["Crawler"])
 async def crawl(request: CrawlRequest):
     url = str(request.url)
     try:
@@ -174,7 +176,7 @@ async def crawl(request: CrawlRequest):
 # -------------------------------
 # ----- Vector DB endpoints -----
 # -------------------------------
-@app.post("/embed", response_model=EmbedResponse, tags=["Embedder"])
+@app.post(f"{API_PREFIX}/embed", response_model=EmbedResponse, tags=["Embedder"])
 async def embed_url(request: EmbedRequest):
     """Orchestrates the full workflow: Crawl -> Chunk -> Embed -> Store."""
     url_str = str(request.url)
@@ -209,7 +211,7 @@ async def embed_url(request: EmbedRequest):



-@app.post("/query", response_model=QueryResponse)
+@app.post(f"{API_PREFIX}/query", response_model=QueryResponse)
 async def query_vector_db(request: QueryRequest):
     """Queries the vector database for text chunks semantically similar to the query."""
     client = get_weaviate_client()
@@ -231,7 +233,7 @@ async def query_vector_db(request: QueryRequest):
 # -------------------------------
 # --- LLM Endpoints -------------
 # -------------------------------
-@app.post("/generate", response_model=GenerateResponse, tags=["LLM"])
+@app.post(f"{API_PREFIX}/generate", response_model=GenerateResponse, tags=["LLM"])
 async def generate_completion(request: GenerateRequest):
     """Generates a text completion using the configured LLM abstraction layer."""
     try:
@@ -245,7 +247,19 @@ async def generate_completion(request: GenerateRequest):
         logging.error(f"ERROR during text generation: {e}")
         raise HTTPException(status_code=500, detail=f"Failed to generate text: {str(e)}")

-
+# ──────────────────────────────────────────────────────────────────────────
+# NEW – main RAG endpoint
+# ──────────────────────────────────────────────────────────────────────────
+@app.post(f"{API_PREFIX}/rag/generate-course", response_model=Course, tags=["rag"])
+async def generate_course(req: CourseGenerationRequest):
+    """
+    • POST because generation is a side-effectful operation (non-idempotent).
+    • Returns a fully-validated Course JSON ready for the course-service.
+    """
+    try:
+        return course_generator.generate_course(req)
+    except Exception as e:
+        raise HTTPException(500, str(e)) from e

 # -------------------------------
 # --------- MAIN ----------------
```
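For reference, a quick sketch of exercising the re-prefixed routes from a client. The ping call mirrors the diff exactly; the generate-course payload fields (`topic`, `difficulty`) are placeholders, since `genai/src/services/rag/schemas.py` is not among the files shown here.

```python
# Sketch: calling the versioned endpoints with httpx.
# NOTE: the generate-course payload fields ("topic", "difficulty") are
# hypothetical -- the rag schemas are not part of this diff.
import httpx

BASE = "http://localhost:8082/api/v1"  # GENAI_PORT defaults to 8082

with httpx.Client(timeout=120.0) as client:
    # Liveness check now lives under the /api/v1 prefix
    print(client.get(f"{BASE}/ping").json())

    # Course generation is a POST: side-effectful and non-idempotent
    resp = client.post(
        f"{BASE}/rag/generate-course",
        json={"topic": "Kubernetes basics", "difficulty": "beginner"},
    )
    resp.raise_for_status()
    print(resp.json())  # validated Course JSON, ready for the course-service
```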

genai/src/services/embedding/embedder_service.py

Lines changed: 34 additions & 1 deletion

```diff
@@ -4,6 +4,9 @@
 from langchain_community.vectorstores.weaviate import Weaviate
 from .weaviate_service import get_weaviate_client, DOCUMENT_CLASS_NAME
 import logging
+from typing import List
+import numpy as np
+from .schemas import QueryResponse, QueryRequest, DocumentResult

 logger = logging.getLogger("skillforge.genai.embedder_service")

@@ -42,4 +45,34 @@ def embed_and_store_text(text: str, source_url: str) -> int:
     else:
         logger.info(f"Stored {num_chunks} chunks in Weaviate for URL {source_url}.")

-    return num_chunks
+    return num_chunks
+
+_embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small")
+
+def embed_text(text: str) -> List[float]:
+    """Generate a single embedding vector from raw text."""
+    return _embeddings_model.embed_query(text)
+
+def cosine_similarity(v1: List[float], v2: List[float]) -> float:
+    """Simple cosine similarity between two vectors."""
+    a = np.array(v1)
+    b = np.array(v2)
+    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
+
+def query_similar_chunks(query_text: str, limit: int = 3) -> QueryResponse:
+    """
+    Stateless helper – identical logic to the /query endpoint but callable in-process.
+    """
+    client = get_weaviate_client()
+    embeddings_model = OpenAIEmbeddings(model="text-embedding-3-small")
+    vector = embeddings_model.embed_query(query_text)
+
+    result = (
+        client.query
+        .get(DOCUMENT_CLASS_NAME, ["content", "source_url"])
+        .with_near_vector({"vector": vector})
+        .with_limit(limit)
+        .do()
+    )
+    docs = [DocumentResult(**d) for d in result["data"]["Get"][DOCUMENT_CLASS_NAME]]
+    return QueryResponse(query=query_text, results=docs)
```
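To see what the new `cosine_similarity` helper computes, here is the same formula on toy vectors, reimplemented inline so the check needs neither an OpenAI key nor a running Weaviate:

```python
import numpy as np

def cosine_similarity(v1, v2):
    # Same formula as the helper added to embedder_service.py
    a, b = np.array(v1), np.array(v2)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

print(cosine_similarity([1.0, 0.0], [1.0, 0.0]))   # 1.0  -> same direction
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))   # 0.0  -> orthogonal
print(cosine_similarity([3.0, 4.0], [6.0, 8.0]))   # 1.0  -> magnitude-invariant
print(cosine_similarity([1.0, 0.0], [-1.0, 0.0]))  # -1.0 -> opposite direction
```

OpenAI's embedding vectors are normalized to unit length, so for `embed_text` outputs the cosine reduces to a plain dot product; the explicit norm division keeps the helper safe for arbitrary vectors.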

genai/src/services/embedding/schemas.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -1,4 +1,3 @@
-# genai/src/services/embedding/schemas.py
 from pydantic import BaseModel, HttpUrl
 from typing import List, Optional

```
genai/src/services/embedding/weaviate_service.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -6,7 +6,7 @@
 DOCUMENT_CLASS_NAME = "DocumentChunk"

 WEAVIATE_HOST = os.getenv("WEAVIATE_HOST", "localhost")
-WEAVIATE_HTTP_PORT = int(os.getenv("WEAVIATE_HTTP_PORT", "1234"))
+WEAVIATE_HTTP_PORT = int(os.getenv("WEAVIATE_HTTP_PORT", "8080"))
 WEAVIATE_GRPC_PORT = int(os.getenv("WEAVIATE_GRPC_PORT", "50051"))

 def get_weaviate_client() -> weaviate.Client:
```
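The new 8080 default matches the HTTP port a stock Weaviate container exposes; 1234 looks like a leftover placeholder that only worked with a remapped port. A minimal connection sketch, assuming the v3 weaviate-client API implied by the `weaviate.Client` return type:

```python
import os
import weaviate  # v3 client API assumed, per the weaviate.Client return type

host = os.getenv("WEAVIATE_HOST", "localhost")
port = int(os.getenv("WEAVIATE_HTTP_PORT", "8080"))  # new default

client = weaviate.Client(url=f"http://{host}:{port}")
print(client.is_ready())  # True once Weaviate answers on that port
```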

genai/src/services/llm/llm_service.py

Lines changed: 67 additions & 3 deletions

```diff
@@ -1,8 +1,15 @@
 import os
-import logging
 from langchain_openai import ChatOpenAI
+import json
+import logging
 from langchain_community.llms import FakeListLLM
 from langchain_core.language_models.base import BaseLanguageModel
+from typing import List, Type, TypeVar
+from pydantic import BaseModel, ValidationError
+
+logger = logging.getLogger(__name__)
+T = TypeVar("T", bound=BaseModel)
+

 def llm_factory() -> BaseLanguageModel:
     """
@@ -11,7 +18,7 @@ def llm_factory() -> BaseLanguageModel:
     Supports OpenAI, OpenAI-compatible (local/llmstudio), and dummy models.
     """
     provider = os.getenv("LLM_PROVIDER", "dummy").lower()
-    logging.info(f"--- Creating LLM for provider: {provider} ---")
+    logger.info(f"--- Creating LLM for provider: {provider} ---")

     if provider in ("openai", "llmstudio", "local"):
         # Get API base and key from env
@@ -59,4 +66,61 @@ def generate_text(prompt: str) -> str:
     if hasattr(response, 'content'):
         return response.content
     else:
-        return response
+        return response
+
+
+def generate_structured(
+    messages: List[dict],
+    schema: Type[T],
+    *,
+    max_retries: int = 3,
+) -> T:
+    """Return a Pydantic object regardless of provider (OpenAI JSON-mode or fallback)."""
+    provider = os.getenv("LLM_PROVIDER", "dummy").lower()
+
+    # 1) OpenAI native JSON mode
+    if provider == "openai":
+        try:
+            from openai import OpenAI
+            client = OpenAI(
+                api_key=os.getenv("OPENAI_API_KEY"),
+                base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
+            )
+            resp = client.beta.chat.completions.parse(
+                model=os.getenv("OPENAI_MODEL", "gpt-4o-mini"),
+                messages=messages,
+                response_format=schema,
+            )
+            return resp.choices[0].message.parsed  # type: ignore[arg-type]
+        except Exception as e:
+            logger.warning(f"OpenAI structured parse failed – falling back: {e}")
+
+    # 2) Generic JSON-string fallback
+    system_json_guard = {
+        "role": "system",
+        "content": (
+            "Return ONLY valid JSON matching this schema:\n"
+            + json.dumps(schema.model_json_schema())
+        ),
+    }
+    convo: List[dict] = [system_json_guard] + messages
+    llm = LLM_SINGLETON
+
+    for attempt in range(1, max_retries + 1):
+        raw = llm.invoke(convo)
+        text = raw.content if hasattr(raw, "content") else raw
+        try:
+            return schema.model_validate_json(text)
+        except ValidationError as e:
+            logger.warning(
+                f"Structured output validation failed ({attempt}/{max_retries}): {e}"
+            )
+            convo += [
+                {"role": "assistant", "content": text},
+                {
+                    "role": "user",
+                    "content": "❌ JSON invalid. Send ONLY fixed JSON.",
+                },
+            ]
+
+    raise ValueError("Could not obtain valid structured output")
```