import json
import math
import os
import statistics
import subprocess
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import wraps
from pathlib import Path

import numpy as np


# ---------------------------------------------------------------------------
# Benchmark decorator
# ---------------------------------------------------------------------------

def benchmark(n=3):
    """Run the wrapped test n times, print the mean wall-clock time, return the last result."""
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            times = []
            result = None
            for _ in range(n):
                start = time.perf_counter()
                result = fn(*args, **kwargs)
                times.append(time.perf_counter() - start)
            avg = statistics.mean(times)
            print(f"\n[benchmark] {fn.__name__}: mean={avg*1000:.1f} ms over {n} runs")
            return result
        return wrapper
    return decorator


# ---------------------------------------------------------------------------
# Model helpers
# ---------------------------------------------------------------------------

def get_model_hf_params():
    """Default lightweight embedding model."""
    return {
        "hf_repo": "ggml-org/embeddinggemma-300M-qat-q4_0-GGUF",
        "hf_file": "embeddinggemma-300M-qat-Q4_0.gguf",
    }


def ensure_model_downloaded(params=None):
    """Run a minimal llama-embedding invocation so the model is fetched into LLAMA_CACHE."""
    repo_root = Path(__file__).resolve().parents[2]
    cache_dir = os.environ.get("LLAMA_CACHE", "tmp")
    emb_path = repo_root / "build/bin/llama-embedding"
    if not emb_path.exists() and os.name == "nt":
        emb_path = repo_root / "build/bin/Release/llama-embedding.exe"
    if not emb_path.exists():
        raise FileNotFoundError(f"llama-embedding not found at {emb_path}")

    params = params or get_model_hf_params()
    cmd = [
        str(emb_path),
        "-hfr", params["hf_repo"],
        "-hff", params["hf_file"],
        "--ctx-size", "16",
        "--embd-output-format", "json",
        "--no-warmup",
        "--threads", "1",
    ]

    env = os.environ.copy()
    env["LLAMA_CACHE"] = cache_dir
    result = subprocess.run(cmd, input="ok", capture_output=True, text=True, env=env)
    if result.returncode != 0:
        raise RuntimeError(f"Failed to download model:\n{result.stderr}")
    return params


def run_embedding(text: str, fmt: str = "raw", params=None):
    """Run llama-embedding on `text` and return its stdout in the requested output format."""
    repo_root = Path(__file__).resolve().parents[2]
    exe = repo_root / "build/bin/llama-embedding"
    if not exe.exists() and os.name == "nt":
        # mirror the Windows fallback used in ensure_model_downloaded()
        exe = repo_root / "build/bin/Release/llama-embedding.exe"
    assert exe.exists(), f"Missing binary: {exe}"

    params = ensure_model_downloaded(params)
    cache_dir = os.environ.get("LLAMA_CACHE", "tmp")

    cmd = [
        str(exe),
        "-hfr", params["hf_repo"],
        "-hff", params["hf_file"],
        "--ctx-size", "2048",
        "--embd-output-format", fmt,
    ]

    env = os.environ.copy()
    env["LLAMA_CACHE"] = cache_dir

    out = subprocess.run(cmd, input=text, capture_output=True, text=True, env=env)
    if out.returncode != 0:
        print(out.stderr)
        raise AssertionError(f"embedding binary failed (code {out.returncode})")
    return out.stdout.strip()
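

# The tests below repeatedly parse raw output into vectors and compare them with
# cosine similarity. The two helpers here factor that out; they are an optional
# sketch (the names parse_raw_embedding / cosine are illustrative) and the
# existing tests do not depend on them.
def parse_raw_embedding(out: str) -> np.ndarray:
    """Parse the whitespace-separated floats printed by --embd-output-format raw."""
    return np.array([float(x) for x in out.split()])


def cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two embedding vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))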


# ---------------------------------------------------------------------------
# 1️⃣ RAW vs JSON baseline tests
# ---------------------------------------------------------------------------

@benchmark(n=3)
def test_embedding_raw_and_json_consistency():
    """
    Run both output modes and verify that they yield the same embedding
    dimension and nearly identical vectors (cosine similarity close to 1).
    """
    out_raw = run_embedding("hello world", "raw")
    floats_raw = np.array([float(x) for x in out_raw.split()])

    out_json = run_embedding("hello world", "json")
    j = json.loads(out_json)
    floats_json = np.array(j["data"][0]["embedding"])

    assert len(floats_raw) == len(floats_json), "Embedding dimension mismatch"
    cos = np.dot(floats_raw, floats_json) / (np.linalg.norm(floats_raw) * np.linalg.norm(floats_json))
    print(f"Cosine similarity raw vs json: {cos:.4f}")
    # expect high similarity but not perfect (formatting precision differences)
    assert cos > 0.999, f"Unexpected divergence between raw and json output ({cos:.4f})"
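

# The JSON branch above assumes a payload whose top-level "data" list holds the
# embedding in its first element (as parsed via j["data"][0]["embedding"]).
# A slightly more defensive extraction could look like this sketch; the helper
# name extract_json_embedding is illustrative and unused by the tests.
def extract_json_embedding(payload: str) -> np.ndarray:
    j = json.loads(payload)
    data = j.get("data")
    assert isinstance(data, list) and data, f"Unexpected JSON schema, keys: {list(j)}"
    return np.array(data[0]["embedding"])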


@benchmark(n=3)
def test_embedding_perf_regression_raw_vs_json():
    """
    Compare raw and json output modes.
    Ensures raw mode is not significantly slower and does not produce a
    significantly larger payload on stdout.
    """
    text = "performance regression test " * 512
    params = ensure_model_downloaded()

    def run(fmt):
        # Each run_embedding() call also re-invokes ensure_model_downloaded(),
        # which adds the same fixed overhead to both formats, so the relative
        # comparison below still holds.
        start = time.perf_counter()
        out = run_embedding(text, fmt, params)
        dur = time.perf_counter() - start
        mem = len(out)  # stdout size in characters, used as a payload-size proxy
        return dur, mem

    t_raw, m_raw = run("raw")
    t_json, m_json = run("json")

    print(f"[perf] raw={t_raw:.3f}s ({m_raw/1e3:.1f} KB) | json={t_json:.3f}s ({m_json/1e3:.1f} KB)")
    # raw should never be significantly slower or produce a much larger payload
    assert t_raw <= t_json * 1.2, f"raw too slow vs json ({t_raw:.3f}s vs {t_json:.3f}s)"
    assert m_raw <= m_json * 1.2, f"raw output unexpectedly larger ({m_raw} vs {m_json} bytes)"


# ---------------------------------------------------------------------------
# 2️⃣ Edge-case coverage
# ---------------------------------------------------------------------------

def test_embedding_empty_input():
    """
    Empty input should not crash and should yield a deterministic, finite embedding.
    Some models (e.g. Gemma/BGE) emit a BOS token embedding with norm ≈ 1.0.
    """
    out1 = run_embedding("", "raw")
    out2 = run_embedding("", "raw")

    floats1 = np.array([float(x) for x in out1.split()])
    floats2 = np.array([float(x) for x in out2.split()])

    # Basic validity
    assert len(floats1) > 0, "Empty input produced no embedding"
    assert np.all(np.isfinite(floats1)), "Embedding contains NaN or inf"
    norm = np.linalg.norm(floats1)
    assert 0.5 <= norm <= 1.5, f"Unexpected norm for empty input: {norm}"

    # Determinism check: cosine similarity should be ≈ 1
    cos = np.dot(floats1, floats2) / (np.linalg.norm(floats1) * np.linalg.norm(floats2))
    assert cos > 0.9999, f"Empty input not deterministic (cos={cos:.4f})"
    print(f"[empty] norm={norm:.4f}, cos={cos:.6f}")


def test_embedding_special_characters():
    """Unicode and punctuation coverage."""
    special_text = "你好 🌍\n\t!@#$%^&*()_+-=[]{}|;:'\",.<>?/`~"
    out = run_embedding(special_text, "raw")
    floats = [float(x) for x in out.split()]
    assert len(floats) > 10
    norm = np.linalg.norm(floats)
    assert math.isfinite(norm) and norm > 0


@benchmark(n=1)
def test_embedding_very_long_input():
    """Stress test for context limit handling."""
    long_text = "lorem " * 10000
    out = run_embedding(long_text, "raw")
    floats = [float(x) for x in out.split()]
    print(f"Output floats (long input): {len(floats)}")
    assert len(floats) > 100
    assert np.isfinite(np.linalg.norm(floats))


# ---------------------------------------------------------------------------
# 3️⃣ Legacy and concurrency coverage (unchanged)
# ---------------------------------------------------------------------------

@benchmark(n=3)
def test_embedding_raw_vector_shape():
    out = run_embedding("hello world", "raw")
    floats = [float(x) for x in out.split()]
    print(f"Embedding size: {len(floats)} floats")
    assert len(floats) > 100
    norm = np.linalg.norm(floats)
    assert 0.5 < norm < 2.0


@benchmark(n=3)
def test_embedding_large_vector_output():
    text = " ".join(["hello"] * 4096)
    out = run_embedding(text, "raw")
    valid_dims = {384, 768, 1024, 1280, 2048, 4096}
    floats = [float(x) for x in out.split()]
    print(f"Output floats: {len(floats)}")
    assert len(floats) in valid_dims, (
        f"Unexpected embedding size: {len(floats)}. Expected one of {sorted(valid_dims)}."
    )


def run_one(args):
    """Worker for the concurrency test: run one llama-embedding subprocess and return its wall time."""
    i, params, text = args
    repo_root = Path(__file__).resolve().parents[2]
    exe = repo_root / "build/bin/llama-embedding"
    if not exe.exists() and os.name == "nt":
        # mirror the Windows fallback used in ensure_model_downloaded()
        exe = repo_root / "build/bin/Release/llama-embedding.exe"
    cache_dir = os.environ.get("LLAMA_CACHE", "tmp")

    cmd = [
        str(exe),
        "-hfr", params["hf_repo"],
        "-hff", params["hf_file"],
        "--ctx-size", "1024",
        "--embd-output-format", "raw",
        "--threads", "1",
    ]

    env = os.environ.copy()
    env["LLAMA_CACHE"] = cache_dir
    start = time.perf_counter()
    result = subprocess.run(cmd, input=text, capture_output=True, text=True, env=env)
    if result.returncode != 0:
        print(f"[worker {i}] stderr:\n{result.stderr}")
        raise AssertionError(f"embedding run {i} failed (code {result.returncode})")
    return time.perf_counter() - start


@benchmark(n=1)
def test_embedding_concurrent_invocations():
    params = ensure_model_downloaded()
    text = " ".join(["concurrency"] * 128)
    n_workers = 4
    with ProcessPoolExecutor(max_workers=n_workers) as pool:
        futures = [pool.submit(run_one, (i, params, text)) for i in range(n_workers)]
        times = [f.result() for f in as_completed(futures)]
    avg = statistics.mean(times)
    print(f"[concurrency] {n_workers} parallel runs: mean={avg*1000:.1f} ms")
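

# Sketch of a complementary check (not part of the original suite): beyond
# timing, concurrent invocations could be verified to agree with one another.
# It reuses run_embedding() plus the illustrative parse_raw_embedding()/cosine()
# helpers above, and assumes deterministic output, as the empty-input test does.
def check_concurrent_outputs_agree(n_workers=4):
    params = ensure_model_downloaded()
    text = " ".join(["concurrency"] * 128)
    with ProcessPoolExecutor(max_workers=n_workers) as pool:
        outs = list(pool.map(run_embedding, [text] * n_workers,
                             ["raw"] * n_workers, [params] * n_workers))
    vecs = [parse_raw_embedding(o) for o in outs]
    for v in vecs[1:]:
        assert cosine(vecs[0], v) > 0.9999, "parallel runs produced diverging embeddings"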


@benchmark(n=1)
def test_embedding_large_model_logging_stress():
    """Optional stress test using a larger model for the stdout/mutex path."""
    large_model = {
        "hf_repo": "TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
        "hf_file": "mistral-7b-instruct-v0.2.Q4_K_M.gguf",
    }
    text = " ".join(["benchmark"] * 8192)
    out = run_embedding(text, "raw", params=large_model)
    floats = [float(x) for x in out.split()]
    assert len(floats) >= 1024


def test_embedding_invalid_flag():
    """
    An invalid flag should produce a non-zero exit and a helpful error message.
    Ensures CLI argument parsing fails gracefully instead of crashing.
    """
    repo_root = Path(__file__).resolve().parents[2]
    exe = repo_root / "build/bin/llama-embedding"
    if not exe.exists() and os.name == "nt":
        # mirror the Windows fallback used in ensure_model_downloaded()
        exe = repo_root / "build/bin/Release/llama-embedding.exe"
    assert exe.exists(), f"Missing binary: {exe}"

    # Pass an obviously invalid flag to trigger error handling.
    result = subprocess.run(
        [str(exe), "--no-such-flag"],
        capture_output=True,
        text=True,
    )

    # Must return non-zero and print something meaningful to stderr.
    assert result.returncode != 0, "Expected non-zero exit on invalid flag"
    stderr_lower = result.stderr.lower()
    assert (
        "error" in stderr_lower
        or "invalid" in stderr_lower
        or "unknown" in stderr_lower
    ), f"Unexpected stderr output: {result.stderr}"
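

if __name__ == "__main__":
    # Convenience entry point (a sketch, assuming the suite is normally collected
    # by pytest): running the file directly executes every test_* function in
    # alphabetical order. Note that the large-model stress test downloads a much
    # larger GGUF, so this is only intended for ad-hoc local runs.
    for _name, _fn in sorted(globals().items()):
        if _name.startswith("test_") and callable(_fn):
            print(f"=== {_name} ===")
            _fn()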