Skip to content

Commit e69ca5e

Browse files
add image hashing and LMEVAL_HASHMM envar (#2973)
* add image hashing * remove unused params description * use `LMEVAL_HASHMM` (default '1' hashes images; set to '0' to save raw images) --------- Co-authored-by: Baber <[email protected]>
1 parent 0e96cd1 commit e69ca5e

File tree

2 files changed

+59
-1
lines changed

2 files changed

+59
-1
lines changed

lm_eval/evaluator.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import itertools
22
import json
33
import logging
4+
import os
45
import random
56
import time
67
from collections import defaultdict
@@ -29,6 +30,7 @@
2930
from lm_eval.tasks import TaskManager, get_task_dict
3031
from lm_eval.utils import (
3132
handle_non_serializable,
33+
hash_dict_images,
3234
hash_string,
3335
positional_deprecated,
3436
setup_logging,
@@ -140,7 +142,6 @@ def simple_evaluate(
140142
Random seed for fewshot sampler random generator. If set to None, the seed of generator will be set to None.
141143
:param metadata: dict
142144
Additional metadata to be added to the task manager. Will get passed to the download function of the task.
143-
144145
return
145146
Dictionary of results
146147
"""
@@ -747,6 +748,12 @@ def evaluate(
747748
},
748749
}
749750
if log_samples:
751+
# default: hash images
752+
samples = (
753+
hash_dict_images(samples)
754+
if os.environ.get("LMEVAL_HASHMM", "1") != "0"
755+
else samples
756+
)
750757
results_dict["samples"] = dict(samples)
751758

752759
return results_dict

lm_eval/utils.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,3 +550,54 @@ def weighted_f1_score(items):
550550
preds = unzipped_list[1]
551551
fscore = f1_score(golds, preds, average="weighted")
552552
return fscore
553+
554+
555+
def convert_pil_to_hash(value):
    """Return the SHA-256 hex digest of an image's PNG encoding.

    Args:
        value: An object exposing ``save(fp, format=...)``, e.g. a
            ``PIL.Image.Image``. It is serialized to PNG in memory.

    Returns:
        str: Hex digest of the PNG-encoded bytes written by ``value.save``.
    """
    from io import BytesIO

    img_bytes = BytesIO()
    value.save(img_bytes, format="PNG")
    # Hash the buffer's *contents*. ``str(img_bytes)`` is only the object's
    # repr (``<_io.BytesIO object at 0x...>``), which has nothing to do with
    # the image data and varies with the memory address.
    return hashlib.sha256(img_bytes.getvalue()).hexdigest()
561+
562+
563+
def convert_bytes_to_hash(value):
    """Return the SHA-256 hex digest of a bytes-like value.

    Args:
        value: Usually ``bytes`` or ``bytearray``. Any other object is hashed
            via the UTF-8 encoding of ``str(value)``.

    Returns:
        str: Hex digest of the payload.
    """
    if not isinstance(value, (bytes, bytearray, memoryview)):
        # Not a raw buffer: fall back to hashing its string form.
        value = str(value).encode()
    # Feed the buffer to the hash directly rather than hashing ``str(value)``.
    # The repr differs between ``bytes`` and ``bytearray`` holding the same
    # data, copies the whole payload into an escaped string, and triggers
    # BytesWarning under ``python -b``.
    return hashlib.sha256(value).hexdigest()
565+
566+
567+
def hash_dict_images(data_dict):
    """
    Build a copy of `data_dict` in which every bytes/bytearray and
    PIL.Image.Image value is replaced by its hash.

    Dicts, lists and tuples are rebuilt recursively (tuples stay tuples);
    every other value is placed in the result unchanged, i.e. the same object
    is reused rather than copied.

    Parameters:
        data_dict (dict): The input dictionary with arbitrary nesting of dicts,
            lists and tuples.

    Returns:
        dict: A new dictionary with the same structure as `data_dict`, but with
            all bytes and PIL.Image.Image objects replaced by their hashes.

    Raises:
        TypeError: If `data_dict` is not a dictionary.
    """
    # Ensure the top-level is a dict
    if not isinstance(data_dict, dict):
        raise TypeError("Input must be a dictionary")

    try:
        from PIL import Image
    except ImportError:
        # Pillow may not be installed. Without it no value can be a PIL
        # image, so only raw byte payloads need hashing. An empty tuple makes
        # the isinstance() check below always False.
        image_types = ()
    else:
        image_types = (Image.Image,)

    def _process_value(value):
        # Bytes -> hash
        if isinstance(value, (bytes, bytearray)):
            return convert_bytes_to_hash(value)
        # PIL Image -> hash
        if isinstance(value, image_types):
            return convert_pil_to_hash(value)
        # Nested dictionary -> recurse
        if isinstance(value, dict):
            return {k: _process_value(v) for k, v in value.items()}
        # List or tuple -> recurse, preserving type
        if isinstance(value, list):
            return [_process_value(v) for v in value]
        if isinstance(value, tuple):
            return tuple(_process_value(v) for v in value)
        # Other types remain unchanged
        return value

    return {key: _process_value(val) for key, val in data_dict.items()}

0 commit comments

Comments
 (0)