125 changes: 117 additions & 8 deletions backend/nlp.py
@@ -7,6 +7,9 @@

from dotenv import load_dotenv
from pydantic import BaseModel
from pathlib import Path

from langfuse import Langfuse, observe, get_client

import openai
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
@@ -20,15 +23,36 @@
}
DEBUG_PROMPTS = False

load_dotenv()
# Load .env
env_path = Path(__file__).parent / ".env"
load_dotenv(dotenv_path=env_path)
openai_api_key = (os.getenv("OPENAI_API_KEY") or "").strip()

if openai_api_key == "":
    raise Exception("OPENAI_API_KEY is not set. Please set it in a .env file.")

# Validate Langfuse configuration
langfuse_public_key = os.getenv("LANGFUSE_PUBLIC_KEY")
langfuse_secret_key = os.getenv("LANGFUSE_SECRET_KEY")
langfuse_base_url = os.getenv("LANGFUSE_BASE_URL")

if not langfuse_public_key or not langfuse_secret_key:
    raise Exception(
        "LANGFUSE_PUBLIC_KEY and LANGFUSE_SECRET_KEY must be set in .env file. "
        f"Current values: PUBLIC_KEY={'set' if langfuse_public_key else 'MISSING'}, "
        f"SECRET_KEY={'set' if langfuse_secret_key else 'MISSING'}, "
        f"BASE_URL={langfuse_base_url or 'MISSING'}"
    )

openai_client = AsyncOpenAI(
    api_key=openai_api_key,
)

langfuse = Langfuse(
    public_key=langfuse_public_key,
    secret_key=langfuse_secret_key,
    host=langfuse_base_url
)
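
As a reference for reviewers, here is a minimal sketch of the .env file this startup block expects. The variable names come from the os.getenv calls above; the values are placeholders, not real credentials:

# backend/.env (placeholder values)
OPENAI_API_KEY=sk-...
LANGFUSE_PUBLIC_KEY=pk-lf-...
LANGFUSE_SECRET_KEY=sk-lf-...
LANGFUSE_BASE_URL=https://cloud.langfuse.com

One note on the Langfuse init: when LANGFUSE_BASE_URL is unset, host=None is passed through, and the SDK typically falls back to its default cloud host rather than failing; if a self-hosted instance is required, validating BASE_URL alongside the keys might be safer.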


async def warmup_nlp():
@@ -157,15 +181,29 @@ class ListResponse(BaseModel):
    responses: List[str]


@observe(name="llm_parse_context", as_type="generation")
async def _get_suggestions_from_context(
    prompt_name: str, doc_context: DocContext, use_false_context: bool = False
    prompt_name: str,
    doc_context: DocContext,
    use_false_context: bool = False
) -> List[str]:
    """Helper function to get suggestions from a specific context"""
    context_type = "false" if use_false_context else "true"

    # Update current observation with metadata (v3 pattern because of langfuse_decorator version issues)
    langfuse = get_client()
    langfuse.update_current_observation(
        metadata={
            "prompt_name": prompt_name,
            "context_type": context_type,
            "use_false_context": use_false_context
        }
    )

    full_prompt = get_full_prompt(
        prompt_name, doc_context, use_false_context=use_false_context
    )
    if DEBUG_PROMPTS:
        context_type = "false" if use_false_context else "true"
        print(f"Prompt for {prompt_name} ({context_type} context):\n{full_prompt}\n")

    completion = await openai_client.chat.completions.parse(
@@ -187,12 +225,34 @@ async def _get_suggestions_from_context(
    return suggestion_response.responses

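For readers new to the Langfuse v3 decorator pattern used throughout this file, here is a minimal, self-contained sketch of how @observe plus get_client() combine with OpenAI's structured parsing. It mirrors the calls above, but the model name, prompt, and function names are placeholders, not part of this module:

from typing import List
from langfuse import observe, get_client
from openai import AsyncOpenAI
from pydantic import BaseModel

client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment

class Ideas(BaseModel):
    responses: List[str]

@observe(name="demo_generation", as_type="generation")
async def demo_generation(topic: str) -> List[str]:
    # Attach metadata to the observation created by @observe (v3 pattern).
    get_client().update_current_observation(metadata={"topic": topic})
    completion = await client.chat.completions.parse(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": f"List three ideas about {topic}."}],
        response_format=Ideas,
    )
    parsed = completion.choices[0].message.parsed
    return parsed.responses if parsed else []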

@observe(name="get_suggestion")
async def get_suggestion(prompt_name: str, doc_context: DocContext) -> GenerationResult:
"""
Main function to get suggestions with Langfuse tracing.
This creates a trace for each suggestion request.
"""
# Update trace with metadata for filtering in Langfuse (v3 pattern)
langfuse = get_client()

langfuse.update_current_trace(
name=f"suggestion_{prompt_name}",
metadata={
"suggestion_type": prompt_name, # Primary field for evaluation filtering
"has_false_context": doc_context.falseContextData is not None and len(doc_context.falseContextData) > 0,
"has_true_context": doc_context.contextData is not None and len(doc_context.contextData) > 0,
"document_length": len(doc_context.beforeCursor + doc_context.selectedText + doc_context.afterCursor),
"has_selection": len(doc_context.selectedText) > 0,
"model": MODEL_PARAMS["model"]
},
tags=[prompt_name, "suggestion"],
session_id=prompt_name # Groups all traces of same type together
)
# Special handling for complete_document: always use false context only, plain completion
if prompt_name == "complete_document":
full_prompt = get_full_prompt(prompt_name, doc_context, use_false_context=True)
if DEBUG_PROMPTS:
print(f"Prompt for {prompt_name} (false context only):\n{full_prompt}\n")

completion = await openai_client.chat.completions.create(
**MODEL_PARAMS,
messages=[
@@ -204,13 +264,23 @@ async def get_suggestion(prompt_name: str, doc_context: DocContext) -> Generatio
        result = completion.choices[0].message.content
        if not result:
            raise ValueError("No response found from complete_document.")

        langfuse.update_current_trace(
            output={"result": result}
        )

        return GenerationResult(generation_type=prompt_name, result=result, extra_data={})

    # If falseContextData is None/empty, use baseline behavior
    if not doc_context.falseContextData:
        langfuse.update_current_trace(
            metadata={"execution_mode": "baseline"}
        )

        full_prompt = get_full_prompt(prompt_name, doc_context)
        if DEBUG_PROMPTS:
            print(f"Prompt for {prompt_name} (baseline):\n{full_prompt}\n")

        completion = await openai_client.chat.completions.parse(
            **MODEL_PARAMS,
            messages=[
@@ -224,16 +294,29 @@ async def get_suggestion(prompt_name: str, doc_context: DocContext) -> Generatio
        )

        suggestion_response = completion.choices[0].message.parsed
        if not suggestion_response or not suggestion_response:
        if not suggestion_response or not suggestion_response.responses:
            raise ValueError("No suggestions found in the response.")

        markdown_response = "\n\n".join(
            [f"- {item}" for item in suggestion_response.responses]
        )

        langfuse.update_current_trace(
            output={
                "result": markdown_response,
                "suggestions": suggestion_response.responses
            }
        )

        return GenerationResult(
            generation_type=prompt_name, result=markdown_response, extra_data={}
        )

    # Study mode: parallel calls with mixing
    langfuse.update_current_trace(
        metadata={"execution_mode": "study_mode_with_mixing"}
    )

    true_suggestions_task = _get_suggestions_from_context(
        prompt_name, doc_context, use_false_context=False
    )
@@ -248,6 +331,9 @@ async def get_suggestion(prompt_name: str, doc_context: DocContext) -> Generatio

    if len(true_suggestions) == 0 or len(false_suggestions) == 0:
        # One or both of the queries refused.
        langfuse.update_current_trace(
            metadata={"refusal": True}
        )
        return GenerationResult(generation_type=prompt_name, result="", extra_data={})

    if len(true_suggestions) != 3 or len(false_suggestions) != 3:
@@ -274,7 +360,7 @@ async def get_suggestion(prompt_name: str, doc_context: DocContext) -> Generatio
    request_hash = hashlib.sha256(
        json.dumps(request_body, sort_keys=True).encode()
    ).hexdigest()
    shuffle_seed = int(request_hash[:8], 16)  # Use first 8 hex chars as seed
    shuffle_seed = int(request_hash[:8], 16)  # Use first 8 hex digits as seed

    # Combine and shuffle suggestions
    all_suggestions = []
@@ -321,6 +407,22 @@ async def get_suggestion(prompt_name: str, doc_context: DocContext) -> Generatio
            1 for item in selected_suggestions if item["source"] == "false"
        ),
    }

    # Add mixing results to trace metadata
    langfuse.update_current_trace(
        metadata={
            "mixing_stats": {
                "true_count": extra_data["total_true_suggestions"],
                "false_count": extra_data["total_false_suggestions"],
                "shuffle_seed": shuffle_seed
            }
        },
        output={
            "result": markdown_response,
            "suggestions": [item["content"] for item in selected_suggestions],
            "sources": [item["source"] for item in selected_suggestions]
        }
    )

    return GenerationResult(
        generation_type=prompt_name, result=markdown_response, extra_data=extra_data
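
The seed derivation above makes the true/false mixing order reproducible: the same request body always produces the same shuffle. A self-contained sketch of that technique, with illustrative names rather than this module's, assuming a JSON-serializable request body:

import hashlib
import json
import random

def deterministic_shuffle(items: list, request_body: dict) -> list:
    # Hash the canonicalized request so identical requests shuffle identically.
    request_hash = hashlib.sha256(
        json.dumps(request_body, sort_keys=True).encode()
    ).hexdigest()
    seed = int(request_hash[:8], 16)  # first 8 hex digits -> 32-bit seed
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)  # local RNG, global state untouched
    return shuffled

# Same request, same order, across processes and restarts.
assert deterministic_shuffle([1, 2, 3], {"q": "x"}) == deterministic_shuffle([1, 2, 3], {"q": "x"})
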
Expand All @@ -342,8 +444,6 @@ async def chat(
    )

    result = response.choices[0].message.content

    # FIXME: figure out why result might ever be None
    return result or ""


@@ -356,8 +456,17 @@ def chat_stream(messages: Iterable[ChatCompletionMessageParam], temperature: flo
    )

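Assuming chat_stream returns the awaitable streaming response from AsyncOpenAI with stream=True (its body is collapsed in this diff, so this is a sketch under that assumption), a consumer could look like:

async def print_stream() -> None:
    stream = await chat_stream(
        messages=[{"role": "user", "content": "Say hello."}],
        temperature=0.7,
    )
    async for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta:
            print(delta, end="", flush=True)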

@observe(name="reflection", as_type="generation")
async def reflection(userDoc: str, paragraph: str) -> GenerationResult:
    temperature = 1.0

    langfuse = get_client()
    langfuse.update_current_observation(
        metadata={
            "temperature": temperature,
            "generation_type": "reflection"
        }
    )

    questions = await chat(
        messages=[
@@ -374,4 +483,4 @@ async def reflection(userDoc: str, paragraph: str) -> GenerationResult:
"prompt": userDoc,
"temperature": temperature,
},
)
)
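
Finally, a hedged end-to-end usage sketch. The DocContext field names follow the attribute accesses in get_suggestion above, but the exact model definition and valid prompt_name values live elsewhere in this codebase, so treat "continue_writing" and the field types as assumptions:

import asyncio

async def main() -> None:
    doc_context = DocContext(
        beforeCursor="The quick brown fox ",
        selectedText="",
        afterCursor="",
        contextData=[],          # assumed list type
        falseContextData=None,   # None -> baseline path, no mixing
    )
    result = await get_suggestion("continue_writing", doc_context)  # hypothetical prompt_name
    print(result.result)

if __name__ == "__main__":
    asyncio.run(main())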