Commit 408acf8

fix(lm): implement OpenAI and Ollama LM adapters; support LM_PROVIDER; resolve abstract LM error
1 parent c1b26b4 commit 408acf8

2 files changed: 149 additions & 11 deletions

orbit_agent/config.py

Lines changed: 43 additions & 11 deletions
@@ -93,6 +93,7 @@ def __post_init__(self):
 def _determine_model_and_key() -> tuple[str, Optional[str]]:
     """Determine which model and API key to use"""
     explicit_model = os.getenv("ORBIT_LM")
+    provider_hint = os.getenv("LM_PROVIDER")  # Optional compatibility env: openai|anthropic|ollama
     openai_key = os.getenv("OPENAI_API_KEY")
     anthropic_key = os.getenv("ANTHROPIC_API_KEY")

@@ -111,13 +112,26 @@ def _determine_model_and_key() -> tuple[str, Optional[str]]:
     else:
         return explicit_model, None
 
-    # Auto-detect based on available keys
+    # Auto-detect based on available keys or provider hint
+    if provider_hint:
+        hint = provider_hint.lower()
+        if hint.startswith("ollama"):
+            return DEFAULT_OLLAMA, None
+        if hint.startswith("openai"):
+            if not openai_key:
+                raise ValueError("LM_PROVIDER=openai requires OPENAI_API_KEY")
+            return DEFAULT_OPENAI, openai_key
+        if hint.startswith("anthropic"):
+            if not anthropic_key:
+                raise ValueError("LM_PROVIDER=anthropic requires ANTHROPIC_API_KEY")
+            return DEFAULT_ANTHROPIC, anthropic_key
+
     if openai_key:
         return DEFAULT_OPENAI, openai_key
     elif anthropic_key:
         return DEFAULT_ANTHROPIC, anthropic_key
     else:
-        logger.info("No API keys found, defaulting to Ollama")
+        logger.info("No API keys found and no provider hint, defaulting to Ollama")
         return DEFAULT_OLLAMA, None

@@ -185,15 +199,33 @@ def configure_lm() -> AppConfig:
     config = get_config()
 
     try:
-        # Use generic LM wrapper; provider inferred from model prefix (openai/, anthropic/, ollama_*)
-        # API keys are read from environment by provider integrations.
-        lm = dspy.LM(
-            model=config.lm.model,
-            temperature=config.lm.temperature,
-            max_tokens=config.lm.max_tokens,
-        )
-        dspy.configure(lm=lm)
-        logger.info(f"Configured LM via dspy.LM: {config.lm.model}")
+        # Use custom lightweight LM adapters compatible with dsp/dspy Predict
+        from .lm_providers import OpenAIChatLM, OllamaLM
+
+        model = config.lm.model
+        lm_impl = None
+        if model.startswith("openai/"):
+            mname = model.replace("openai/", "")
+            lm_impl = OpenAIChatLM(
+                model=mname,
+                api_key=config.lm.api_key or os.getenv("OPENAI_API_KEY"),
+                temperature=config.lm.temperature,
+                max_tokens=config.lm.max_tokens,
+                api_base=os.getenv("OPENAI_BASE_URL"),
+            )
+        elif model.startswith("ollama_chat/") or model.startswith("ollama/"):
+            mname = model.replace("ollama_chat/", "").replace("ollama/", "")
+            lm_impl = OllamaLM(
+                model=mname,
+                base_url=config.lm.api_base or os.getenv("OLLAMA_API_BASE", "http://localhost:11434"),
+                temperature=config.lm.temperature,
+                max_tokens=config.lm.max_tokens,
+            )
+        else:
+            raise ValueError(f"Unsupported model provider for '{model}'. Use openai/ or ollama_chat/.")
+
+        dspy.configure(lm=lm_impl)
+        logger.info(f"Configured LM: {config.lm.model}")
         return config
 
     except Exception as e:

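With this change, a provider can be selected explicitly through LM_PROVIDER instead of key-based auto-detection. A minimal sketch of exercising the new path (assumes the orbit_agent package is importable, that no conflicting ORBIT_LM or API-key variables are set, and that the DEFAULT_* model strings are defined elsewhere in config.py):

import os

# Force the local provider; Ollama needs no API key.
os.environ["LM_PROVIDER"] = "ollama"
os.environ.pop("OPENAI_API_KEY", None)
os.environ.pop("ANTHROPIC_API_KEY", None)

from orbit_agent.config import configure_lm

# Resolves the model, builds the matching adapter, registers it via
# dspy.configure(), and returns the AppConfig.
config = configure_lm()
print(config.lm.model)  # expected: the DEFAULT_OLLAMA model string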
orbit_agent/lm_providers.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+from __future__ import annotations
+
+import os
+import json
+from typing import List, Dict, Any
+
+import requests
+from dsp.modules.lm import LM
+
+
+class OpenAIChatLM(LM):
+    """Minimal OpenAI Chat Completions adapter for dsp/dspy LM interface.
+
+    Returns a list of completion strings for a given prompt.
+    """
+
+    def __init__(
+        self,
+        model: str,
+        api_key: str | None,
+        temperature: float = 0.0,
+        max_tokens: int = 400,
+        api_base: str | None = None,
+    ):
+        super().__init__(model)
+        self.provider = "openai"
+        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
+        self.api_base = api_base or os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
+        self.kwargs["temperature"] = temperature
+        self.kwargs["max_tokens"] = max_tokens
+
+    def basic_request(self, prompt: str, **kwargs) -> Dict[str, Any]:
+        if not self.api_key:
+            raise ValueError("OPENAI_API_KEY not set")
+
+        url = f"{self.api_base}/chat/completions"
+        payload = {
+            "model": self.kwargs["model"],
+            "messages": [
+                {"role": "system", "content": "You are a concise, direct startup advisor."},
+                {"role": "user", "content": prompt},
+            ],
+            "temperature": kwargs.get("temperature", self.kwargs.get("temperature", 0.0)),
+            "max_tokens": kwargs.get("max_tokens", self.kwargs.get("max_tokens", 400)),
+            "n": kwargs.get("n", self.kwargs.get("n", 1)),
+        }
+        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
+        resp = requests.post(url, headers=headers, data=json.dumps(payload), timeout=60)
+        resp.raise_for_status()
+        return resp.json()
+
+    def __call__(self, prompt: str, only_completed: bool = True, return_sorted: bool = False, **kwargs) -> List[str]:
+        data = self.basic_request(prompt, **kwargs)
+        choices = data.get("choices", [])
+        out: List[str] = []
+        for ch in choices:
+            msg = ch.get("message", {})
+            content = msg.get("content")
+            if content:
+                out.append(content)
+        if not out and "error" in data:
+            raise RuntimeError(f"OpenAI error: {data['error']}")
+        if not out:
+            # Fallback to an empty string to avoid crashes
+            out = [""]
+        return out
+
+
+class OllamaLM(LM):
+    """Minimal Ollama generate adapter for dsp/dspy LM interface.
+
+    Uses /api/generate (non-streaming) and returns a single completion string.
+    """
+
+    def __init__(
+        self,
+        model: str,
+        base_url: str = "http://localhost:11434",
+        temperature: float = 0.0,
+        max_tokens: int = 400,
+    ):
+        super().__init__(model)
+        self.provider = "ollama"
+        self.base_url = base_url.rstrip("/")
+        self.kwargs["temperature"] = temperature
+        self.kwargs["max_tokens"] = max_tokens
+
+    def basic_request(self, prompt: str, **kwargs) -> Dict[str, Any]:
+        url = f"{self.base_url}/api/generate"
+        payload = {
+            "model": self.kwargs["model"],
+            "prompt": prompt,
+            "stream": False,
+            "options": {
+                "temperature": kwargs.get("temperature", self.kwargs.get("temperature", 0.0)),
+                "num_predict": kwargs.get("max_tokens", self.kwargs.get("max_tokens", 400)),
+            },
+        }
+        resp = requests.post(url, json=payload, timeout=120)
+        resp.raise_for_status()
+        return resp.json()
+
+    def __call__(self, prompt: str, only_completed: bool = True, return_sorted: bool = False, **kwargs) -> List[str]:
+        data = self.basic_request(prompt, **kwargs)
+        text = data.get("response", "")
+        return [text]

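Both adapters return a list of completion strings and can also be called directly, outside dspy. An illustrative sketch (the model tag is a placeholder; assumes a local Ollama server on the default port):

from orbit_agent.lm_providers import OllamaLM

# Placeholder model tag; any model already pulled into the local Ollama
# instance works. Requires a running server at base_url.
lm = OllamaLM(model="llama3.1", base_url="http://localhost:11434", max_tokens=200)
completions = lm("Name the riskiest assumption in a two-sided marketplace.")
print(completions[0])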