Commit 19e2b99

Merge pull request #250 from jaideepr97/add-ruler
feat: add RULER long context evaluation
2 parents bc03a1f + d0c1f5c commit 19e2b99

File tree

4 files changed: +173 -0 lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ issues = "https://github.com/instructlab/eval/issues"
 "mt_bench" = "instructlab.eval.mt_bench:MTBenchEvaluator"
 "mt_bench_branch" = "instructlab.eval.mt_bench:MTBenchBranchEvaluator"
 "leaderboard_v2" = "instructlab.eval.leaderboard:LeaderboardV2Evaluator"
+"ruler" = "instructlab.eval.ruler:RulerEvaluator"

 [tool.setuptools_scm]
 version_file = "src/instructlab/eval/_version.py"
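
For reference, the entry point registered above can be resolved at runtime the same way the existing evaluators are. A minimal sketch (not part of this change), assuming Python 3.10+ and that the package is installed with this commit applied:

# Sketch only: look up the "ruler" evaluator through the entry point group above.
from importlib.metadata import entry_points

eps = entry_points(group="instructlab.eval.evaluator")
ruler_ep = next(ep for ep in eps if ep.name == "ruler")
RulerEvaluator = ruler_ep.load()  # resolves to instructlab.eval.ruler:RulerEvaluator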

requirements-ruler.txt

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+lm-eval[ruler]>=0.4.8

src/instructlab/eval/ruler.py

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
# Standard
from typing import Any, Dict, List, Optional
import json
import os
import pathlib

# Third Party
from lm_eval.evaluator import simple_evaluate

# First Party
from instructlab.eval.evaluator import Evaluator

RULER_TASKS = [
    "niah_single_1",
    "niah_single_2",
    "niah_single_3",
    "niah_multikey_1",
    "niah_multikey_2",
    "niah_multikey_3",
    "niah_multiquery",
    "niah_multivalue",
    "ruler_vt",
    "ruler_cwe",
    "ruler_fwe",
    "ruler_qa_hotpot",
    "ruler_qa_squad",
]

DEFAULT_MAX_LENGTH = 4096


class RulerEvaluator(Evaluator):
    """
    Class definition for running RULER benchmarking tasks.
    """

    name = "ruler"

    def __init__(
        self,
        model_path: Optional[str] = None,
        output_file: Optional[str] = None,
        tasks: list[str] = RULER_TASKS,
        api_endpoint: Optional[str] = None,
        max_length: Optional[int] = None,
    ) -> None:
        self.model_path = model_path
        self.tasks = tasks
        self.results: Dict[Any, Any] = {}
        self.output_file = output_file

        self.api_endpoint = api_endpoint or None
        self.max_length = max_length or DEFAULT_MAX_LENGTH

    def save_to_file(self, output_file: Optional[str] = None) -> None:
        """Save results to a JSON file."""
        output_file = output_file if output_file else self.output_file
        if not output_file:
            raise ValueError("Output file path cannot be empty")

        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(self.results, f, indent=2)

    def process_lm_eval_results(
        self,
        fpath: Optional[pathlib.Path] = None,
        raw_results: Optional[dict[str, Any]] = None,
    ) -> dict[str, float]:
        """
        Process the evaluation results from lm_eval for the given file path and extract
        aggregated scores for each context length.

        Args:
            fpath (pathlib.Path): The file path to the evaluation results.
        """
        unique_metrics_dict: dict[str, Any] = {}

        # This is required because the lm_eval results are nested under 'ruler' if
        # that is the task supplied to it. The output contains a nested dictionary
        # in this case, using RULER tasks as the key. Each context length is a further
        # subkey in the dictionary. There is an additional key per context length which
        # also contains the score adjusted for stderr, which we are ignoring here.
        def extract_metrics(results: dict, unique_metrics_dict: dict) -> dict:
            for k, v in results.items():
                if isinstance(v, dict):
                    extract_metrics(v, unique_metrics_dict)
                else:
                    if "stderr" not in k:
                        metric = k.split(",")[0]
                        if metric not in unique_metrics_dict:
                            unique_metrics_dict[metric] = []
                        unique_metrics_dict[metric].append(v)

            return unique_metrics_dict

        if fpath:
            with open(fpath, "r", encoding="utf-8") as f:
                raw_results = json.load(f)

        if raw_results is not None:
            extract_metrics(raw_results["results"], unique_metrics_dict)
        unique_float_metrics = {}
        # if a value is a list of floats, average the list
        for k, v in unique_metrics_dict.items():
            if isinstance(v, list) and all(isinstance(i, float) for i in v):
                unique_float_metrics[k] = sum(v) / len(v)

        # find the average of all float values in the dict
        float_values = [
            v for v in unique_float_metrics.values() if isinstance(v, float)
        ]
        if float_values:
            unique_float_metrics["avg"] = sum(float_values) / len(float_values)
        else:
            unique_float_metrics["avg"] = 0.0

        # result format
        # {'8192': 0.90, '32768': 0.82, '65536': 0.77, '131072': 0.71, 'avg': 0.80}
        return unique_float_metrics

    def run(
        self,
        model_path: Optional[str] = None,
        tasks: Optional[List[str]] = None,
        output_file: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        max_length: Optional[int] = DEFAULT_MAX_LENGTH,
    ) -> None:
        """
        Run the RULER evaluation using the specified model and tasks.
        """

        model_path = self.model_path if model_path is None else model_path
        tasks = self.tasks if not tasks else tasks
        output_file = self.output_file if not output_file else output_file

        # validate the above params are not None and the output file can be written to
        if not model_path:
            raise ValueError("Model path cannot be empty")
        if not output_file:
            raise ValueError("Output file path cannot be empty")
        if not api_endpoint:
            raise ValueError("API endpoint cannot be empty")

        # Prepare model_args for lm_eval's "local-completions" model backend
        model_args = {
            "pretrained": model_path,
            "base_url": api_endpoint,
            "max_length": max_length,
        }

        self.lm_eval_results = simple_evaluate(
            model="local-completions",
            model_args=model_args,
            tasks=tasks,
        )

        self.result = self.process_lm_eval_results(
            raw_results=self.lm_eval_results,
        )

        # write results to file
        if output_file:
            try:
                with open(output_file, "w", encoding="utf-8") as f:
                    json.dump(self.result, f, indent=2)
            except (OSError, IOError) as e:
                raise ValueError(f"Failed to write to output file: {e}") from e
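
A minimal usage sketch for the new evaluator (not part of this change); the model path, output file name, and endpoint URL below are placeholders, and the model is assumed to already be served behind an OpenAI-compatible completions API:

# Sketch only: run RULER against a locally served model.
# The model path, file name, and URL are placeholders, not values from this change.
from instructlab.eval.ruler import RulerEvaluator

evaluator = RulerEvaluator(
    model_path="/models/my-model",
    output_file="ruler_results.json",
    api_endpoint="http://localhost:8000/v1/completions",
)
# Note: run() validates api_endpoint from its own argument rather than falling
# back to self.api_endpoint, so pass it explicitly.
evaluator.run(api_endpoint=evaluator.api_endpoint)

# evaluator.result holds the aggregated scores keyed by context length plus "avg",
# and the same dict is written to ruler_results.json.
print(evaluator.result)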

tests/test_project.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 from instructlab.eval.leaderboard import LeaderboardV2Evaluator
 from instructlab.eval.mmlu import MMLUBranchEvaluator, MMLUEvaluator
 from instructlab.eval.mt_bench import MTBenchBranchEvaluator, MTBenchEvaluator
+from instructlab.eval.ruler import RulerEvaluator


 def test_evaluator_eps():
@@ -16,6 +17,7 @@ def test_evaluator_eps():
         "mt_bench": MTBenchEvaluator,
         "mt_bench_branch": MTBenchBranchEvaluator,
         "leaderboard_v2": LeaderboardV2Evaluator,
+        "ruler": RulerEvaluator,
     }
     eps = entry_points(group="instructlab.eval.evaluator")
     found = {}
