import numpy as np
from aenum import extend_enum

import tasks_examples.custom_tasks_with_custom_metrics.ifeval.instructions_registry as instructions_registry
from lighteval.metrics import Metrics
from lighteval.metrics.utils import (
    MetricCategory,
    MetricUseCase,
    SampleLevelMetricGrouping,
)
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc


# We create the task config
ifeval = LightevalTaskConfig(
    name="ifeval",
    prompt_function="ifeval_prompt",
    suite=["custom"],
    hf_repo="wis-k/instruction-following-eval",
    hf_subset="default",
    metric=["ifeval_metric"],
    hf_avail_splits=["train"],
    evaluation_splits=["train"],
    few_shots_split="train",
    few_shots_select="random_sampling",
    generation_size=1280,
    stop_sequence=[],  # no stop sequence, will use eot token
)
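# Note: the "ifeval_metric" name above has to match the name under which the metric
# grouping is registered via extend_enum(Metrics, "ifeval_metric", ...) at the end of
# this file.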


# Very specific task: there is no single gold output; instead we check whether the
# format of the generation obeys the rules attached to each prompt.
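# Illustrative shape of a dataset row (hypothetical values; real rows come from the
# HF dataset configured above):
#   {
#       "prompt": "Summarize the plot in two paragraphs. Do not use any commas.",
#       "instruction_id_list": ["punctuation:no_comma"],
#       "kwargs": [{}],
#   }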
def ifeval_prompt(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=line["prompt"],
        choices=[""],
        gold_index=0,
        instruction="",
        specific={"instructions_id_list": line["instruction_id_list"], "kwargs": line["kwargs"]},
    )


submetric_names = [
    "prompt_level_strict_acc",
    "inst_level_strict_acc",
    "prompt_level_loose_acc",
    "inst_level_loose_acc",
]

def ifeval_metric(predictions: list[str], formatted_doc: Doc, **kwargs) -> dict:
    response = predictions[0]

    # Instruction ids, kwargs and original prompt for this sample
    instruction_list = formatted_doc.specific["instructions_id_list"]
    all_kwargs = formatted_doc.specific["kwargs"]
    prompt = formatted_doc.query

    # Response variants for the loose check: drop the first and/or last line, strip markdown asterisks
    r = response.split("\n")
    response_remove_first = "\n".join(r[1:]).strip()
    response_remove_last = "\n".join(r[:-1]).strip()
    response_remove_both = "\n".join(r[1:-1]).strip()
    revised_response = response.replace("*", "")
    revised_response_remove_first = response_remove_first.replace("*", "")
    revised_response_remove_last = response_remove_last.replace("*", "")
    revised_response_remove_both = response_remove_both.replace("*", "")
    all_responses = [
        response,
        revised_response,
        response_remove_first,
        response_remove_last,
        response_remove_both,
        revised_response_remove_first,
        revised_response_remove_last,
        revised_response_remove_both,
    ]

    is_following_list_strict = []
    is_following_list_loose = []

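    # For each instruction attached to the prompt: the strict score checks the raw
    # response; the loose score counts the instruction as followed if any of the
    # response variants built above passes.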
    for index, instruction_id in enumerate(instruction_list):
        instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
        instruction = instruction_cls(instruction_id)

        # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method.
        task_kwargs = {k: v for k, v in all_kwargs[index].items() if v}
        instruction.build_description(**task_kwargs)
        args = instruction.get_instruction_args()
        if args and "prompt" in args:
            instruction.build_description(prompt=prompt)

        # Strict
        if response.strip() and instruction.check_following(response):
            is_following_list_strict.append(True)
        else:
            is_following_list_strict.append(False)

        # Loose
        is_following = False
        for r in all_responses:
            if r.strip() and instruction.check_following(r):
                is_following = True
                break

        is_following_list_loose.append(is_following)

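    # prompt_level_*: 1 only if every instruction for this prompt is followed;
    # inst_level_*: one boolean per instruction, aggregated across the corpus below.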
    return {
        "prompt_level_strict_acc": int(all(is_following_list_strict)),
        "inst_level_strict_acc": is_following_list_strict,
        "prompt_level_loose_acc": int(all(is_following_list_loose)),
        "inst_level_loose_acc": is_following_list_loose,
    }


def agg_inst_level_acc(items):
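    # Corpus-level aggregation for the instruction-level metrics:
    # flatten the per-sample lists and average, e.g. [[True, False], [True]] -> 2/3.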
    flat_items = [item for sublist in items for item in sublist]
    inst_level_acc = sum(flat_items) / len(flat_items)
    return inst_level_acc


ifeval_metrics = SampleLevelMetricGrouping(
    metric=submetric_names,
    higher_is_better={n: True for n in submetric_names},
    category=MetricCategory.GENERATIVE,
    use_case=MetricUseCase.ACCURACY,
    sample_level_fn=ifeval_metric,
    corpus_level_fn={
        "prompt_level_strict_acc": np.mean,
        "inst_level_strict_acc": agg_inst_level_acc,
        "prompt_level_loose_acc": np.mean,
        "inst_level_loose_acc": agg_inst_level_acc,
    },
)


_TASKS = [ifeval]

# Convert to dict for lighteval
TASKS_TABLE = [task.as_dict() for task in _TASKS]
# Adds the metric to the metric list!
extend_enum(Metrics, "ifeval_metric", ifeval_metrics)

if __name__ == "__main__":
    print([t["name"] for t in TASKS_TABLE])
    print(len(TASKS_TABLE))
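

# Hypothetical invocation (script name, flags and paths depend on your lighteval
# version; treat this as a sketch, not the exact command):
#
#   python run_evals_accelerate.py \
#       --model_args "pretrained=<your_model>" \
#       --tasks "custom|ifeval|0|0" \
#       --custom_tasks tasks_examples/custom_tasks_with_custom_metrics/ifeval/ifeval.py \
#       --output_dir ./evals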