| 1 | +# genai/src/services/llm/llm_service.py |
1 | 2 | import os |
2 | | -from langchain_openai import ChatOpenAI |
3 | 3 | import json |
4 | 4 | import logging |
5 | | -from langchain_community.llms import FakeListLLM |
6 | | -from langchain_core.language_models.base import BaseLanguageModel |
7 | 5 | from typing import List, Type, TypeVar |
| 6 | + |
8 | 7 | from pydantic import BaseModel, ValidationError |
| 8 | +from langchain_openai import ChatOpenAI |
| 9 | +from langchain_community.llms import FakeListLLM |
| 10 | +from langchain_core.language_models.base import BaseLanguageModel |
9 | 11 |
10 | 12 | logger = logging.getLogger(__name__) |
| 13 | + |
11 | 14 | T = TypeVar("T", bound=BaseModel) |
12 | 15 |
| 16 | +# ────────────────────────────────────────────────────────────────────────── |
| 17 | +# LLM factory |
| 18 | +# ────────────────────────────────────────────────────────────────────────── |
13 | 19 |
14 | 20 | def llm_factory() -> BaseLanguageModel: |
15 | | - """ |
16 | | - Factory function to create and return an LLM instance based on the provider |
17 | | - specified in the environment variables. |
18 | | - Supports OpenAI, OpenAI-compatible (local/llmstudio), and dummy models. |
19 | | - """ |
| 21 | + """Return a singleton LangChain LLM according to $LLM_PROVIDER.""" |
20 | 22 | provider = os.getenv("LLM_PROVIDER", "dummy").lower() |
21 | 23 | logger.info(f"--- Creating LLM for provider: {provider} ---") |
22 | 24 |
23 | 25 | if provider in ("openai", "llmstudio", "local"): |
24 | | - # Get API base and key from env |
25 | 26 | openai_api_key = os.getenv("OPENAI_API_KEY", "sk-xxx-dummy-key") |
26 | 27 | openai_api_base = os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1") |
27 | | - |
28 | 28 | model = os.getenv("OPENAI_MODEL", "gpt-4o-mini") |
29 | 29 | return ChatOpenAI( |
30 | 30 | model=model, |
31 | 31 | temperature=0.7, |
32 | 32 | openai_api_key=openai_api_key, |
33 | | - openai_api_base=openai_api_base |
| 33 | + openai_api_base=openai_api_base, |
34 | 34 | ) |
35 | | - |
36 | | - elif provider == "dummy": |
| 35 | + |
| 36 | + if provider == "dummy": |
37 | 37 | responses = [ |
38 | 38 | "The first summary from the dummy LLM is about procedural languages.", |
39 | 39 | "The second summary is about object-oriented programming.", |
40 | 40 | "This is a fallback response.", |
41 | 41 | ] |
42 | 42 | return FakeListLLM(responses=responses) |
43 | 43 |
44 | | - else: |
45 | | - raise ValueError(f"Currently Unsupported LLM provider: {provider}") |
| 44 | + raise ValueError(f"Unsupported LLM provider: {provider}") |
| 45 | + |
46 | 46 |
47 | 47 | LLM_SINGLETON = llm_factory() |
48 | 48 |
| 49 | +# ────────────────────────────────────────────────────────────────────────── |
| 50 | +# Convenience helpers |
| 51 | +# ────────────────────────────────────────────────────────────────────────── |
| 52 | + |
49 | 53 | def generate_text(prompt: str) -> str: |
50 | | - """ |
51 | | - Generates a text completion for a given prompt using the configured LLM. |
52 | | - """ |
53 | | - # 1. Get the correct LLM instance from our factory |
| 54 | + """Simple text completion (legacy helper).""" |
54 | 55 | llm = LLM_SINGLETON |
55 | 56 |
56 | | - # if we using local LLM, we need to append "/no_think" in case the model is a thinking model |
57 | | - if os.getenv("LLM_PROVIDER", "dummy").lower() == "llmstudio" and hasattr(llm, 'model_name'): |
58 | | - prompt += "/no_think" |
59 | | - |
60 | | - # 2. Invoke the LLM with the prompt |
| 57 | + if os.getenv("LLM_PROVIDER", "dummy").lower() == "llmstudio" and hasattr(llm, "model_name"): |
| 58 | + prompt += "/no_think" |
| 59 | + |
61 | 60 | response = llm.invoke(prompt) |
| 61 | + return response.content if hasattr(response, "content") else response |
| 62 | + |
| 63 | + |
| 64 | +def generate_structured(messages: List[dict], schema: Type[T], *, max_retries: int = 3) -> T: |
| 65 | + """Return a Pydantic object *schema* regardless of the underlying provider. |
62 | 66 |
63 | | - # 3. The response object's structure can vary slightly by model. |
64 | | - # For Chat models, the text is in the .content attribute. |
65 | | - # For standard LLMs (like our FakeListLLM), it's the string itself. |
66 | | - if hasattr(response, 'content'): |
67 | | - return response.content |
68 | | - else: |
69 | | - return response |
70 | | - |
71 | | - |
72 | | -def generate_structured( |
73 | | - messages: List[dict], |
74 | | - schema: Type[T], |
75 | | - *, |
76 | | - max_retries: int = 3, |
77 | | -) -> T: |
78 | | - """Return a Pydantic object regardless of provider (OpenAI JSON-mode or fallback).""" |
| 67 | + 1. For $LLM_PROVIDER==openai we use the native `beta.chat.completions.parse` API. |
| 68 | + 2. Otherwise we fall back to strict JSON prompting and `model_validate_json()`. |
| 69 | + """ |
79 | 70 | provider = os.getenv("LLM_PROVIDER", "dummy").lower() |
80 | 71 |
81 | | - # 1) OpenAI native JSON mode |
| 72 | + # ── 1. Native OpenAI JSON mode ─────────────────────────────────────── |
82 | 73 | if provider == "openai": |
83 | 74 | try: |
84 | | - from openai import OpenAI |
| 75 | + from openai import OpenAI # local import to avoid hard dep for other providers |
| 76 | + |
85 | 77 | client = OpenAI( |
86 | 78 | api_key=os.getenv("OPENAI_API_KEY"), |
87 | 79 | base_url=os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"), |
88 | 80 | ) |
89 | | - resp = client.beta.chat.completions.parse( |
| 81 | + response = client.beta.chat.completions.parse( |
90 | 82 | model=os.getenv("OPENAI_MODEL", "gpt-4o-mini"), |
91 | 83 | messages=messages, |
92 | 84 | response_format=schema, |
93 | 85 | ) |
94 | | - return resp.choices[0].message.parsed # type: ignore[arg-type] |
| 86 | + return response.choices[0].message.parsed # type: ignore[arg-type] |
95 | 87 | except Exception as e: |
96 | 88 | logger.warning(f"OpenAI structured parse failed – falling back: {e}") |
97 | 89 |
98 | | - # 2) Generic JSON-string fallback |
| 90 | + # ── 2. Generic JSON-string fallback for all other models ───────────── |
99 | 91 | system_json_guard = { |
100 | 92 | "role": "system", |
101 | 93 | "content": ( |
102 | | - "Return ONLY valid JSON matching this schema:\n" |
103 | | - + json.dumps(schema.model_json_schema()) |
| 94 | + "You are a JSON-only assistant. Produce **only** valid JSON that conforms to " |
| 95 | + "this schema (no markdown, no explanations):\n" + json.dumps(schema.model_json_schema()) |
104 | 96 | ), |
105 | 97 | } |
| 98 | + |
106 | 99 | convo: List[dict] = [system_json_guard] + messages |
107 | | - llm = LLM_SINGLETON |
108 | 100 |
| 101 | + llm = LLM_SINGLETON |
109 | 102 | for attempt in range(1, max_retries + 1): |
110 | 103 | raw = llm.invoke(convo) |
111 | | - text = raw.content if hasattr(raw, "content") else raw |
| 104 | + text = raw.content if hasattr(raw, "content") else raw # Chat vs non-chat |
112 | 105 | try: |
113 | 106 | return schema.model_validate_json(text) |
114 | 107 | except ValidationError as e: |
115 | 108 | logger.warning( |
116 | | - f"Structured output validation failed ({attempt}/{max_retries}): {e}" |
| 109 | + f"Structured output validation failed (try {attempt}/{max_retries}): {e}"\ |
117 | 110 | ) |
118 | | - convo += [ |
119 | | - {"role": "assistant", "content": text}, |
120 | | - { |
121 | | - "role": "user", |
122 | | - "content": "❌ JSON invalid. Send ONLY fixed JSON.", |
123 | | - }, |
124 | | - ] |
125 | | - |
126 | | - raise ValueError("Could not obtain valid structured output") |
| 111 | + convo.append({"role": "assistant", "content": text}) |
| 112 | + convo.append({ |
| 113 | + "role": "user", |
| 114 | + "content": ( |
| 115 | + "❌ JSON was invalid: " + str(e.errors()) + |
| 116 | + "\nPlease resend ONLY the corrected JSON (no extraneous text)." |
| 117 | + ), |
| 118 | + }) |
| 119 | + |
| 120 | + raise ValueError("Failed to get valid structured output after retries") |
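For context, a minimal usage sketch (not part of this commit): it assumes the module is importable as `services.llm.llm_service`, that an OpenAI-compatible provider is configured via `LLM_PROVIDER`, and uses a hypothetical `Summary` schema purely for illustration.

```python
# Illustrative only; import path and Summary schema are assumptions, not part of the commit.
from typing import List

from pydantic import BaseModel

from services.llm.llm_service import generate_text, generate_structured


class Summary(BaseModel):
    title: str
    bullet_points: List[str]


# Plain text completion (works out of the box with LLM_PROVIDER=dummy).
print(generate_text("Summarise procedural programming in one sentence."))

# Structured completion: returns a validated Summary instance, using OpenAI
# JSON mode when LLM_PROVIDER=openai and the JSON-prompt fallback otherwise.
# (Requires a real provider; the dummy FakeListLLM will not produce valid JSON.)
summary = generate_structured(
    [{"role": "user", "content": "Summarise object-oriented programming."}],
    schema=Summary,
)
print(summary.title, summary.bullet_points)
```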