# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import json
import os
import tempfile
from datetime import datetime
from pathlib import Path

import pytest
from datasets import Dataset
from huggingface_hub import HfApi

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.logging.info_loggers import DetailsLogger

# ruff: noqa
from tests.fixtures import TESTING_EMPTY_HF_ORG_ID, testing_empty_hf_org_id


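# Note: `evaluation_tracker` is a custom pytest marker; it is assumed to be registered
# in the project's pytest configuration so it does not raise PytestUnknownMarkWarning.
# This fixture reads the marker's kwargs (if any) and forwards the supported ones to
# EvaluationTracker, writing all output to a temporary directory. For example:
#   @pytest.mark.evaluation_tracker(save_details=True)
#   def test_something(mock_evaluation_tracker): ...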
@pytest.fixture
def mock_evaluation_tracker(request):
    passed_params = {}
    if request.keywords.get("evaluation_tracker"):
        passed_params = request.keywords["evaluation_tracker"].kwargs

    with tempfile.TemporaryDirectory() as temp_dir:
        kwargs = {
            "output_dir": temp_dir,
            "save_details": passed_params.get("save_details", False),
            "push_to_hub": passed_params.get("push_to_hub", False),
            "push_to_tensorboard": passed_params.get("push_to_tensorboard", False),
            "hub_results_org": passed_params.get("hub_results_org", ""),
        }
        tracker = EvaluationTracker(**kwargs)
        tracker.general_config_logger.model_name = "test_model"
        yield tracker


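# Freezes the datetime seen by the evaluation tracker so that the date-stamped
# output paths (and the `date_id` checked below) are deterministic.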
@pytest.fixture
def mock_datetime(monkeypatch):
    mock_date = datetime(2023, 1, 1, 12, 0, 0)

    class MockDatetime:
        @classmethod
        def now(cls):
            return mock_date

        @classmethod
        def fromisoformat(cls, date_string: str):
            return mock_date

    monkeypatch.setattr("lighteval.logging.evaluation_tracker.datetime", MockDatetime)
    return mock_date


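# Saving aggregated metrics should produce a single results_*.json file under
# results/<model_name>/ that round-trips the logged metrics and model name.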
def test_results_logging(mock_evaluation_tracker: EvaluationTracker):
    task_metrics = {
        "task1": {"accuracy": 0.8, "f1": 0.75},
        "task2": {"precision": 0.9, "recall": 0.85},
    }
    mock_evaluation_tracker.metrics_logger.metric_aggregated = task_metrics

    mock_evaluation_tracker.save()

    results_dir = Path(mock_evaluation_tracker.output_dir) / "results" / "test_model"
    assert results_dir.exists()

    result_files = list(results_dir.glob("results_*.json"))
    assert len(result_files) == 1

    with open(result_files[0], "r") as f:
        saved_results = json.load(f)

    assert "results" in saved_results
    assert saved_results["results"] == task_metrics
    assert saved_results["config_general"]["model_name"] == "test_model"


@pytest.mark.evaluation_tracker(save_details=True)
def test_details_logging(mock_evaluation_tracker, mock_datetime):
    task_details = {
        "task1": [DetailsLogger.CompiledDetail(truncated=10, padded=5)],
        "task2": [DetailsLogger.CompiledDetail(truncated=20, padded=10)],
    }
    mock_evaluation_tracker.details_logger.details = task_details

    mock_evaluation_tracker.save()

    date_id = mock_datetime.isoformat().replace(":", "-")
    details_dir = Path(mock_evaluation_tracker.output_dir) / "details" / "test_model" / date_id
    assert details_dir.exists()

    for task in ["task1", "task2"]:
        file_path = details_dir / f"details_{task}_{date_id}.parquet"
        dataset = Dataset.from_parquet(str(file_path))
        assert len(dataset) == 1
        assert int(dataset[0]["truncated"]) == task_details[task][0].truncated
        assert int(dataset[0]["padded"]) == task_details[task][0].padded


@pytest.mark.evaluation_tracker(save_details=False)
def test_no_details_output(mock_evaluation_tracker: EvaluationTracker):
    mock_evaluation_tracker.save()

    details_dir = Path(mock_evaluation_tracker.output_dir) / "details" / "test_model"
    assert not details_dir.exists()


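# This test talks to the real Hugging Face Hub: it assumes valid credentials are
# available in the environment and that the `testing_empty_hf_org_id` fixture
# provides (and cleans up) a dedicated, otherwise empty test organization.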
@pytest.mark.evaluation_tracker(push_to_hub=True, hub_results_org=TESTING_EMPTY_HF_ORG_ID)
def test_push_to_hub_works(testing_empty_hf_org_id, mock_evaluation_tracker: EvaluationTracker, mock_datetime):
    # Prepare the dummy data
    task_metrics = {
        "task1": {"accuracy": 0.8, "f1": 0.75},
        "task2": {"precision": 0.9, "recall": 0.85},
    }
    mock_evaluation_tracker.metrics_logger.metric_aggregated = task_metrics

    task_details = {
        "task1": [DetailsLogger.CompiledDetail(truncated=10, padded=5)],
        "task2": [DetailsLogger.CompiledDetail(truncated=20, padded=10)],
    }
    mock_evaluation_tracker.details_logger.details = task_details
    mock_evaluation_tracker.save()

    # Verify using HfApi
    api = HfApi()

    # Check that the repo exists and is private
    expected_repo_id = f"{testing_empty_hf_org_id}/details_test_model_private"
    assert api.repo_exists(repo_id=expected_repo_id, repo_type="dataset")
    assert api.repo_info(repo_id=expected_repo_id, repo_type="dataset").private

    repo_files = api.list_repo_files(repo_id=expected_repo_id, repo_type="dataset")
    # Check that README.md exists
    assert any(file == "README.md" for file in repo_files)

    # Check that both results files were uploaded
    result_files = [file for file in repo_files if file.startswith("results_")]
    assert len(result_files) == 2
    assert len([file for file in result_files if file.endswith(".json")]) == 1
    assert len([file for file in result_files if file.endswith(".parquet")]) == 1

    # Check that the details dataset was uploaded
    details_files = [file for file in repo_files if "details_" in file and file.endswith(".parquet")]
    assert len(details_files) == 2