File tree Expand file tree Collapse file tree 4 files changed +13
-5
lines changed Expand file tree Collapse file tree 4 files changed +13
-5
lines changed Original file line number Diff line number Diff line change 3838 HF_HOME : " cache/models"
3939 HF_DATASETS_CACHE : " cache/datasets"
4040 run : | # PYTHONPATH="${PYTHONPATH}:src" HF_DATASETS_CACHE="cache/datasets" HF_HOME="cache/models"
41- python -m pytest --disable-pytest-warnings
41+ python -m pytest -x --disable-pytest-warnings
4242 - name : Write cache
4343 uses : actions/cache@v4
4444 with :
Original file line number Diff line number Diff line change @@ -90,6 +90,8 @@ class GeneralConfigLogger:
9090 model_dtype : str = None
9191 model_size : str = None
9292
93+ generation_parameters : dict | None = None
94+
9395 # Nanotron config
9496 config : "Config" = None
9597
@@ -133,14 +135,16 @@ def log_args_info(
133135 self .job_id = job_id
134136 self .config = config
135137
136- def log_model_info (self , model_info : ModelInfo ) -> None :
138+ def log_model_info (self , generation_parameters : dict , model_info : ModelInfo ) -> None :
137139 """
138140 Logs the model information.
139141
140142 Args:
143+ generation_parameters (dict): Generation parameters taken from the model config, to be logged.
141144 model_info (ModelInfo): Model information to be logged.
142145
143146 """
147+ self .generation_parameters = generation_parameters
144148 self .model_name = model_info .model_name
145149 self .model_sha = model_info .model_sha
146150 self .model_dtype = model_info .model_dtype
Original file line number Diff line number Diff line change 2727import re
2828import shutil
2929from contextlib import nullcontext
30- from dataclasses import dataclass , field
30+ from dataclasses import asdict , dataclass , field
3131from datetime import timedelta
3232from enum import Enum , auto
3333
@@ -156,7 +156,9 @@ def __init__(
156156 self .accelerator , self .parallel_context = self ._init_parallelism_manager ()
157157 self .model = self ._init_model (model_config , model )
158158
159- self .evaluation_tracker .general_config_logger .log_model_info (self .model .model_info )
159+ generation_parameters = asdict (model_config .generation_parameters ) if model_config else {}
160+
161+ self .evaluation_tracker .general_config_logger .log_model_info (generation_parameters , self .model .model_info )
160162 self ._init_tasks_and_requests (tasks = tasks )
161163 self ._init_random_seeds ()
162164 # Final results
Original file line number Diff line number Diff line change 2626
2727
2828def test_empty_requests ():
29- model_config = TransformersModelConfig ("hf-internal-testing/tiny-random-LlamaForCausalLM" )
29+ model_config = TransformersModelConfig (
30+ "hf-internal-testing/tiny-random-LlamaForCausalLM" , model_parallel = False , revision = "main"
31+ )
3032 model : TransformersModel = load_model (config = model_config , env_config = EnvConfig (cache_dir = "." ))
3133
3234 assert model .loglikelihood ([]) == []
You can’t perform that action at this time.
0 commit comments