55 changes: 55 additions & 0 deletions src/generator.py
@@ -3,6 +3,12 @@
ANSWER_START = "<<<ANSWER>>>"
ANSWER_END = "<<<END>>>"

try:
    # Prefer persistent in-process generation when available
    from llama_cpp import Llama  # type: ignore
except Exception:  # pragma: no cover
    Llama = None  # Fall back to the CLI

def _project_root() -> pathlib.Path:
    # generator.py is in src/, so project root is parent of that folder
    here = pathlib.Path(__file__).resolve()
@@ -167,8 +173,57 @@ def _extract_answer(raw: str) -> str:
    text = raw.split(ANSWER_START)[-1]
    return text.split(ANSWER_END)[0].strip()

_LLM_CACHE = {}
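# One Llama instance is kept per (model_path, n_ctx, n_threads, n_gpu_layers) key so
# repeated calls reuse already-loaded weights instead of reloading the model.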
def _get_llm(model_path: str, n_ctx: int = 4096, n_threads: int | None = None, n_gpu_layers: int = 0):
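    """Return a cached (or newly created) Llama instance, or None when llama_cpp is unavailable."""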
    key = (model_path, n_ctx, n_threads or os.cpu_count(), n_gpu_layers)
    if key in _LLM_CACHE:
        return _LLM_CACHE[key]
    if Llama is None:
        return None
    if not os.path.isfile(model_path):
        raise FileNotFoundError(
            f"Model file not found: {model_path}.\n"
            "Download a compatible GGUF model and update config/model_path, e.g.:\n"
            "models/qwen2.5-0.5b-instruct-q5_k_m.gguf\n"
            "Or run with: --model_path <path-to-gguf>"
        )

    # Prefer a larger context when using Qwen 2.5 (trained at 32k). Allow override via env.
    try:
        n_ctx_env = int(os.getenv("TOKENS_N_CTX") or os.getenv("LLAMA_N_CTX") or "32768")
    except ValueError:
        n_ctx_env = 32768
    n_ctx = max(n_ctx, n_ctx_env)
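    # The env override can only raise n_ctx above the caller's value, never lower it.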
    llm = Llama(
        model_path=model_path,
        n_ctx=n_ctx,
        n_threads=n_threads or os.cpu_count() or 4,
        n_gpu_layers=n_gpu_layers,  # default CPU-only for broad compatibility
        logits_all=False,
        embedding=False,
        verbose=False,
        n_batch=256,
        use_mmap=True,
    )
    _LLM_CACHE[key] = llm
    return llm

def run_llama_cpp(prompt: str, model_path: str, max_tokens: int = 300,
                  threads: int = 8, n_gpu_layers: int = 8, temperature: float = 0.2):
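    """Generate a completion for ``prompt``, preferring the persistent in-process
    llama_cpp binding and falling back to the llama.cpp CLI when it is unavailable."""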
    llm = _get_llm(model_path=model_path, n_ctx=32768, n_threads=threads or None, n_gpu_layers=n_gpu_layers)
    if llm is not None:
        out = llm.create_completion(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            stop=[ANSWER_END],
        )
        text = (out.get("choices") or [{}])[0].get("text", "")
        if not text.strip():
            raise RuntimeError("llama_cpp create_completion returned empty text.")
        return _extract_answer(text + ANSWER_END)

    # Fall back to the llama.cpp CLI when the Python binding is unavailable
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found: {model_path}. Follow the README steps to download the model.")
    llama_binary = resolve_llama_binary()
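A minimal usage sketch of the new in-process path (assumptions: src/ is importable as a package and a GGUF model exists at the example path from the error message above):

from src.generator import run_llama_cpp, ANSWER_START, ANSWER_END

# Build a prompt that asks the model to wrap its answer in the sentinel markers,
# so _extract_answer can pull the text between ANSWER_START and ANSWER_END.
prompt = (
    "Answer the question concisely and wrap the answer between "
    f"{ANSWER_START} and {ANSWER_END}.\n"
    "Question: What does this project do?\n"
    f"{ANSWER_START}\n"
)

answer = run_llama_cpp(
    prompt,
    model_path="models/qwen2.5-0.5b-instruct-q5_k_m.gguf",  # example path; adjust to your local model
    max_tokens=200,
    threads=8,
    n_gpu_layers=0,  # CPU-only; raise to offload layers when built with GPU support
)
print(answer)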