# Standard
from typing import Any, Dict, List, Optional
import json
import os
import pathlib

# Third Party
from lm_eval.evaluator import simple_evaluate

# First Party
from instructlab.eval.evaluator import Evaluator

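# RULER sub-task names passed through to lm_eval: needle-in-a-haystack (niah_*)
# retrieval variants, variable tracking (vt), common/frequent word extraction
# (cwe/fwe), and long-context QA over HotpotQA and SQuAD.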
RULER_TASKS = [
    "niah_single_1",
    "niah_single_2",
    "niah_single_3",
    "niah_multikey_1",
    "niah_multikey_2",
    "niah_multikey_3",
    "niah_multiquery",
    "niah_multivalue",
    "ruler_vt",
    "ruler_cwe",
    "ruler_fwe",
    "ruler_qa_hotpot",
    "ruler_qa_squad",
]

DEFAULT_MAX_LENGTH = 4096


class RulerEvaluator(Evaluator):
    """
    Evaluator for running the RULER long-context benchmark tasks via lm_eval.
    """

    name = "ruler"

    def __init__(
        self,
        model_path: Optional[str] = None,
        output_file: Optional[str] = None,
        tasks: Optional[List[str]] = None,
        api_endpoint: Optional[str] = None,
        max_length: Optional[int] = None,
    ) -> None:
        self.model_path = model_path
        self.tasks = tasks if tasks is not None else RULER_TASKS
        self.results: Dict[Any, Any] = {}
        self.output_file = output_file

        self.api_endpoint = api_endpoint
        self.max_length = max_length or DEFAULT_MAX_LENGTH

    def save_to_file(self, output_file: Optional[str] = None) -> None:
        """Save results to a JSON file."""
        output_file = output_file if output_file else self.output_file
        if not output_file:
            raise ValueError("Output file path cannot be empty")

        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(self.results, f, indent=2)

    def process_lm_eval_results(
        self,
        fpath: Optional[pathlib.Path] = None,
        raw_results: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, float]:
        """
        Process the evaluation results from lm_eval and extract the aggregated
        score for each context length.

        Args:
            fpath (pathlib.Path): Path to a JSON file of lm_eval results.
            raw_results (dict): lm_eval results already loaded in memory, used
                when fpath is not provided.
        """
        unique_metrics_dict: Dict[str, Any] = {}

        # This helper is needed because lm_eval nests its results under 'ruler'
        # when that is the supplied task. In that case the output is a nested
        # dictionary keyed by RULER task, with each context length as a further
        # subkey. Each context length also carries a stderr-adjusted score key,
        # which is ignored here.
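        # Illustrative sketch of the flattening (the key names and values below
        # are hypothetical): an entry such as
        #   {"ruler": {"niah_single_1": {"4096,none": 0.91, "4096_stderr,none": 0.01}}}
        # is collected by extract_metrics as {"4096": [0.91]}, with the stderr
        # entry skipped.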
        def extract_metrics(
            results: Dict[str, Any], unique_metrics_dict: Dict[str, Any]
        ) -> Dict[str, Any]:
            # Walk the nested results, collecting every non-stderr metric value
            # under its metric name (the part of the key before the comma).
            for k, v in results.items():
                if isinstance(v, dict):
                    extract_metrics(v, unique_metrics_dict)
                elif "stderr" not in k:
                    metric = k.split(",")[0]
                    unique_metrics_dict.setdefault(metric, []).append(v)

            return unique_metrics_dict

        if fpath:
            with open(fpath, "r", encoding="utf-8") as f:
                raw_results = json.load(f)

        if raw_results is None:
            raise ValueError("Either fpath or raw_results must be provided")

        extract_metrics(raw_results["results"], unique_metrics_dict)
        unique_float_metrics = {}
        # if a value is a list of floats, average the list
        for k, v in unique_metrics_dict.items():
            if isinstance(v, list) and all(isinstance(i, float) for i in v):
                unique_float_metrics[k] = sum(v) / len(v)

        # average all float values in the dict into an overall score
        float_values = [
            v for v in unique_float_metrics.values() if isinstance(v, float)
        ]
        if float_values:
            unique_float_metrics["avg"] = sum(float_values) / len(float_values)
        else:
            unique_float_metrics["avg"] = 0.0

        # result format:
        # {'8192': 0.90, '32768': 0.82, '65536': 0.77, '131072': 0.71, 'avg': 0.80}
        return unique_float_metrics

    def run(
        self,
        model_path: Optional[str] = None,
        tasks: Optional[List[str]] = None,
        output_file: Optional[str] = None,
        api_endpoint: Optional[str] = None,
        max_length: Optional[int] = None,
    ) -> None:
        """
        Run the RULER evaluation against the specified model and tasks, falling
        back to the values supplied at construction time for any argument that
        is not given.
        """

        model_path = model_path or self.model_path
        tasks = tasks or self.tasks
        output_file = output_file or self.output_file
        api_endpoint = api_endpoint or self.api_endpoint
        max_length = max_length or self.max_length

        # validate that the required params are set
        if not model_path:
            raise ValueError("Model path cannot be empty")
        if not output_file:
            raise ValueError("Output file path cannot be empty")
        if not api_endpoint:
            raise ValueError("API endpoint cannot be empty")

        # Prepare model_args
        model_args = {
            "pretrained": model_path,
            "base_url": api_endpoint,
            "max_length": max_length,
        }

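        # "local-completions" is lm_eval's client for an OpenAI-compatible
        # completions endpoint: requests go to `base_url` and `pretrained` names
        # the served model.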
        self.lm_eval_results = simple_evaluate(
            model="local-completions",
            model_args=model_args,
            tasks=tasks,
        )

        self.results = self.process_lm_eval_results(
            raw_results=self.lm_eval_results,
        )

        # write results to file
        try:
            self.save_to_file(output_file)
        except (OSError, IOError) as e:
            raise ValueError(f"Failed to write to output file: {e}") from e
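
# Illustrative usage (a sketch, not part of the module; the model path, endpoint
# URL, and output path below are placeholders):
#
#   evaluator = RulerEvaluator(
#       model_path="/path/to/model",
#       output_file="results/ruler.json",
#       api_endpoint="http://localhost:8000/v1/completions",
#   )
#   evaluator.run()
#   print(evaluator.results)  # e.g. {"4096": 0.91, ..., "avg": 0.88}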