Commit 109fad0

Add e2e tests for embedding raw flag
1 parent 8284efc commit 109fad0

File tree

4 files changed: +353, -5 lines changed


.github/workflows/embeddings.yml

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
# Embedding CLI build and tests
name: Embedding CLI

on:
  workflow_dispatch:
  push:
    branches:
      - feature/*
      - master
    paths:
      - '.github/workflows/embeddings.yml'
      - 'examples/embedding/**'
      - 'examples/tests/**'
  pull_request:
    types: [opened, synchronize, reopened]
    paths:
      - '.github/workflows/embeddings.yml'
      - 'examples/embedding/**'
      - 'examples/tests/**'

jobs:
  embedding-cli-tests:
    runs-on: ubuntu-latest

    steps:
      - name: Install system deps
        run: |
          sudo apt-get update
          sudo apt-get -y install \
            build-essential \
            cmake \
            curl \
            python3-pip \
            libcurl4-openssl-dev

      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      - name: Install Python deps
        run: |
          pip install -r requirements.txt || echo "No extra requirements found"
          pip install pytest

      - name: Build llama-embedding
        run: |
          cmake -B build \
            -DCMAKE_BUILD_TYPE=Release
          cmake --build build --target llama-embedding -j $(nproc)

      - name: Run embedding tests
        run: |
          pytest -v examples/tests

examples/embedding/embedding.cpp

Lines changed: 3 additions & 5 deletions
@@ -81,13 +81,11 @@ static void print_raw_embeddings(const float * emb,
     const bool is_rank = (pooling_type == LLAMA_POOLING_TYPE_RANK);
     const int cols = is_rank ? std::min<int>(n_embd, (int) n_cls_out) : n_embd;

+    const char *fmt = embd_normalize == 0 ? "%1.0f%s" : "%1.7f%s";
+
     for (int j = 0; j < n_embd_count; ++j) {
         for (int i = 0; i < cols; ++i) {
-            if (embd_normalize == 0) {
-                LOG("%1.0f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
-            } else {
-                LOG("%1.7f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
-            }
+            LOG(fmt, emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
         }
         LOG("\n");
     }
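The refactor above selects the printf format once before the loops (integers when embd_normalize == 0, seven decimals otherwise) instead of branching on every element, so raw output stays space-separated floats, one embedding per line. A minimal parsing sketch under that assumption; parse_raw_embeddings is illustrative and not part of this commit:

    import numpy as np

    def parse_raw_embeddings(stdout_text: str) -> np.ndarray:
        # Each non-empty line is one embedding: space-separated floats,
        # as emitted by print_raw_embeddings() in the diff above.
        rows = [
            [float(x) for x in line.split()]
            for line in stdout_text.strip().splitlines()
            if line.strip()
        ]
        return np.array(rows)

    # Example: for two input lines the result is a 2 x n_embd matrix.
    # emb = parse_raw_embeddings(completed_process.stdout)

The e2e tests below use the same single-line parsing (out.split()) and compare raw against JSON output with a cosine-similarity threshold rather than exact equality, which tolerates the 7-decimal truncation of the raw format.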

examples/tests/__init__.py

Whitespace-only changes.

examples/tests/test_embedding.py

Lines changed: 291 additions & 0 deletions
@@ -0,0 +1,291 @@
from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import wraps
import numpy as np
from pathlib import Path
import json, os, time, statistics, subprocess, math


# ---------------------------------------------------------------------------
# Benchmark decorator
# ---------------------------------------------------------------------------

def benchmark(n=3):
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            times = []
            result = None
            for _ in range(n):
                start = time.perf_counter()
                result = fn(*args, **kwargs)
                times.append(time.perf_counter() - start)
            avg = statistics.mean(times)
            print(f"\n[benchmark] {fn.__name__}: mean={avg*1000:.1f} ms over {n} runs")
            return result
        return wrapper
    return decorator


# ---------------------------------------------------------------------------
# Model helpers
# ---------------------------------------------------------------------------

def get_model_hf_params():
    """Default lightweight embedding model."""
    return {
        "hf_repo": "ggml-org/embeddinggemma-300M-qat-q4_0-GGUF",
        "hf_file": "embeddinggemma-300M-qat-Q4_0.gguf",
    }


def ensure_model_downloaded(params=None):
    repo_root = Path(__file__).resolve().parents[2]
    cache_dir = os.environ.get("LLAMA_CACHE", "tmp")
    emb_path = repo_root / "build/bin/llama-embedding"
    if not emb_path.exists() and os.name == "nt":
        emb_path = repo_root / "build/bin/Release/llama-embedding.exe"
    if not emb_path.exists():
        raise FileNotFoundError(f"llama-embedding not found at {emb_path}")

    params = params or get_model_hf_params()
    cmd = [
        str(emb_path),
        "-hfr", params["hf_repo"],
        "-hff", params["hf_file"],
        "--ctx-size", "16",
        "--embd-output-format", "json",
        "--no-warmup",
        "--threads", "1",
    ]

    env = os.environ.copy()
    env["LLAMA_CACHE"] = cache_dir
    result = subprocess.run(cmd, input="ok", capture_output=True, text=True, env=env)
    if result.returncode != 0:
        raise RuntimeError(f"Failed to download model:\n{result.stderr}")
    return params


def run_embedding(text: str, fmt: str = "raw", params=None):
    repo_root = Path(__file__).resolve().parents[2]
    exe = repo_root / "build/bin/llama-embedding"
    assert exe.exists(), f"Missing binary: {exe}"

    params = ensure_model_downloaded(params)
    cache_dir = os.environ.get("LLAMA_CACHE", "tmp")

    cmd = [
        str(exe),
        "-hfr", params["hf_repo"],
        "-hff", params["hf_file"],
        "--ctx-size", "2048",
        "--embd-output-format", fmt,
    ]

    env = os.environ.copy()
    env["LLAMA_CACHE"] = cache_dir

    out = subprocess.run(cmd, input=text, capture_output=True, text=True, env=env)
    if out.returncode != 0:
        print(out.stderr)
        raise AssertionError(f"embedding binary failed (code {out.returncode})")
    return out.stdout.strip()


# ---------------------------------------------------------------------------
# 1️⃣ RAW vs JSON baseline tests
# ---------------------------------------------------------------------------

@benchmark(n=3)
def test_embedding_raw_and_json_consistency():
    """
    Run both output modes and verify same embedding shape, norm similarity,
    and small cosine distance.
    """
    out_raw = run_embedding("hello world", "raw")
    floats_raw = np.array([float(x) for x in out_raw.split()])

    out_json = run_embedding("hello world", "json")
    j = json.loads(out_json)
    floats_json = np.array(j["data"][0]["embedding"])

    assert len(floats_raw) == len(floats_json), "Embedding dimension mismatch"
    cos = np.dot(floats_raw, floats_json) / (np.linalg.norm(floats_raw) * np.linalg.norm(floats_json))
    print(f"Cosine similarity raw vs json: {cos:.4f}")
    # expect high similarity but not perfect (formatting precision differences)
    assert cos > 0.999, f"Unexpected divergence between raw and json output ({cos:.4f})"


@benchmark(n=3)
def test_embedding_perf_regression_raw_vs_json():
    """
    Compare performance between raw and json output.
    Ensures raw mode is not significantly slower or memory-heavier.
    """
    text = "performance regression test " * 512
    params = ensure_model_downloaded()

    def run(fmt):
        start = time.perf_counter()
        out = run_embedding(text, fmt, params)
        dur = time.perf_counter() - start
        mem = len(out)
        return dur, mem

    t_raw, m_raw = run("raw")
    t_json, m_json = run("json")

    print(f"[perf] raw={t_raw:.3f}s ({m_raw/1e3:.1f} KB) | json={t_json:.3f}s ({m_json/1e3:.1f} KB)")
    # raw should never be significantly slower or consume wildly more memory
    assert t_raw <= t_json * 1.2, f"raw too slow vs json ({t_raw:.3f}s vs {t_json:.3f}s)"
    assert m_raw <= m_json * 1.2, f"raw output unexpectedly larger ({m_raw} vs {m_json} bytes)"


# ---------------------------------------------------------------------------
# 2️⃣ Edge-case coverage
# ---------------------------------------------------------------------------

def test_embedding_empty_input():
    """
    Empty input should not crash and should yield a deterministic, finite embedding.
    Some models (e.g. Gemma/BGE) emit BOS token embedding with norm ≈ 1.0.
    """
    out1 = run_embedding("", "raw")
    out2 = run_embedding("", "raw")

    floats1 = np.array([float(x) for x in out1.split()])
    floats2 = np.array([float(x) for x in out2.split()])

    # Basic validity
    assert len(floats1) > 0, "Empty input produced no embedding"
    assert np.all(np.isfinite(floats1)), "Embedding contains NaN or inf"
    norm = np.linalg.norm(floats1)
    assert 0.5 <= norm <= 1.5, f"Unexpected norm for empty input: {norm}"

    # Determinism check: cosine similarity should be ≈ 1
    cos = np.dot(floats1, floats2) / (np.linalg.norm(floats1) * np.linalg.norm(floats2))
    assert cos > 0.9999, f"Empty input not deterministic (cos={cos:.4f})"
    print(f"[empty] norm={norm:.4f}, cos={cos:.6f}")


def test_embedding_special_characters():
    """Unicode and punctuation coverage."""
    special_text = "你好 🌍\n\t!@#$%^&*()_+-=[]{}|;:'\",.<>?/`~"
    out = run_embedding(special_text, "raw")
    floats = [float(x) for x in out.split()]
    assert len(floats) > 10
    norm = np.linalg.norm(floats)
    assert math.isfinite(norm) and norm > 0


@benchmark(n=1)
def test_embedding_very_long_input():
    """Stress test for context limit handling."""
    long_text = "lorem " * 10000
    out = run_embedding(long_text, "raw")
    floats = [float(x) for x in out.split()]
    print(f"Output floats (long input): {len(floats)}")
    assert len(floats) > 100
    assert np.isfinite(np.linalg.norm(floats))


# ---------------------------------------------------------------------------
# 3️⃣ Legacy and concurrency coverage (unchanged)
# ---------------------------------------------------------------------------

@benchmark(n=3)
def test_embedding_raw_vector_shape():
    out = run_embedding("hello world", "raw")
    floats = [float(x) for x in out.split()]
    print(f"Embedding size: {len(floats)} floats")
    assert len(floats) > 100
    norm = np.linalg.norm(floats)
    assert 0.5 < norm < 2.0


@benchmark(n=3)
def test_embedding_large_vector_output():
    text = " ".join(["hello"] * 4096)
    out = run_embedding(text, "raw")
    valid_dims = {384, 768, 1024, 1280, 2048, 4096}
    floats = [float(x) for x in out.split()]
    print(f"Output floats: {len(floats)}")
    assert len(floats) in valid_dims, (
        f"Unexpected embedding size: {len(floats)}. Expected one of {sorted(valid_dims)}."
    )


def run_one(args):
    i, params, text = args
    repo_root = Path(__file__).resolve().parents[2]
    exe = repo_root / "build/bin/llama-embedding"
    cache_dir = os.environ.get("LLAMA_CACHE", "tmp")

    cmd = [
        str(exe),
        "-hfr", params["hf_repo"],
        "-hff", params["hf_file"],
        "--ctx-size", "1024",
        "--embd-output-format", "raw",
        "--threads", "1",
    ]

    env = os.environ.copy()
    env["LLAMA_CACHE"] = cache_dir
    start = time.perf_counter()
    result = subprocess.run(cmd, input=text, capture_output=True, text=True, env=env)
    if result.returncode != 0:
        print(f"[worker {i}] stderr:\n{result.stderr}")
        raise AssertionError(f"embedding run {i} failed (code {result.returncode})")
    return time.perf_counter() - start


@benchmark(n=1)
def test_embedding_concurrent_invocations():
    params = ensure_model_downloaded()
    text = " ".join(["concurrency"] * 128)
    n_workers = 4
    with ProcessPoolExecutor(max_workers=n_workers) as pool:
        futures = [pool.submit(run_one, (i, params, text)) for i in range(n_workers)]
        times = [f.result() for f in as_completed(futures)]
    avg = statistics.mean(times)
    print(f"[concurrency] {n_workers} parallel runs: mean={avg*1000:.1f} ms")


@benchmark(n=1)
def test_embedding_large_model_logging_stress():
    """Optional stress test using larger model for stdout/mutex path."""
    large_model = {
        "hf_repo": "TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
        "hf_file": "mistral-7b-instruct-v0.2.Q4_K_M.gguf",
    }
    text = " ".join(["benchmark"] * 8192)
    out = run_embedding(text, "raw", params=large_model)
    floats = [float(x) for x in out.split()]
    assert len(floats) >= 1024


def test_embedding_invalid_flag():
    """
    Invalid flag should produce a non-zero exit and a helpful error message.
    Ensures CLI argument parsing fails gracefully instead of crashing.
    """
    repo_root = Path(__file__).resolve().parents[2]
    exe = repo_root / "build/bin/llama-embedding"
    assert exe.exists(), f"Missing binary: {exe}"

    # Pass an obviously invalid flag to trigger error handling.
    result = subprocess.run(
        [str(exe), "--no-such-flag"],
        capture_output=True,
        text=True,
    )

    # Must return non-zero and print something meaningful to stderr.
    assert result.returncode != 0, "Expected non-zero exit on invalid flag"
    stderr_lower = result.stderr.lower()
    assert (
        "error" in stderr_lower
        or "invalid" in stderr_lower
        or "unknown" in stderr_lower
    ), f"Unexpected stderr output: {result.stderr}"
