
Commit 20bb2a4

committed
merged rebased pr huggingface#656
2 parents: 70f7f9e + eca97eb

5 files changed: +29 additions, -4 deletions

src/lighteval/config/lighteval_config.py

Lines changed: 8 additions & 0 deletions
@@ -101,3 +101,11 @@ class LightEvalConfig:
 class FullNanotronConfig:
     lighteval_config: LightEvalConfig
     nanotron_config: "Config"
+
+    @property
+    def generation_parameters(self):
+        # Return the generation parameters from the lighteval config
+        # or create default generation parameters if none are set
+        if self.lighteval_config.generation:
+            return self.lighteval_config.generation
+        return GenerationArgs()
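The new property gives callers a single accessor for generation settings, falling back to default GenerationArgs() when the lighteval config does not define any. A minimal sketch of the same fallback pattern, using simplified stand-in dataclasses rather than the real nanotron/lighteval config classes:

from dataclasses import dataclass
from typing import Optional

@dataclass
class GenerationArgsSketch:          # stand-in for the real GenerationArgs
    temperature: float = 1.0
    top_p: float = 1.0

@dataclass
class LightEvalConfigSketch:         # stand-in for LightEvalConfig
    generation: Optional[GenerationArgsSketch] = None

@dataclass
class FullNanotronConfigSketch:      # stand-in for FullNanotronConfig
    lighteval_config: LightEvalConfigSketch

    @property
    def generation_parameters(self) -> GenerationArgsSketch:
        # Prefer explicitly configured generation settings, otherwise use defaults.
        if self.lighteval_config.generation:
            return self.lighteval_config.generation
        return GenerationArgsSketch()

cfg = FullNanotronConfigSketch(lighteval_config=LightEvalConfigSketch())
assert cfg.generation_parameters.temperature == 1.0   # falls back to defaults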

src/lighteval/main_nanotron.py

Lines changed: 9 additions & 1 deletion
@@ -45,8 +45,9 @@ def nanotron(
     Evaluate models using nanotron as backend.
     """
     from nanotron.config import Config, get_config_from_file
+    from nanotron.config.parallelism_config import ParallelismArgs

-    from lighteval.config.lighteval_config import FullNanotronConfig, LightEvalConfig
+    from lighteval.config.lighteval_config import FullNanotronConfig, LightEvalConfig, LightEvalLoggingArgs, LightEvalTasksArgs
     from lighteval.logging.evaluation_tracker import EvaluationTracker
     from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
     from lighteval.utils.imports import NO_NANOTRON_ERROR_MSG, is_nanotron_available
@@ -64,6 +65,13 @@ def nanotron(
         skip_unused_config_keys=True,
         skip_null_keys=True,
     )
+    model_config = get_config_from_file(
+        checkpoint_config_path,
+        config_class=Config,
+        model_config_class=None,
+        skip_unused_config_keys=True,
+        skip_null_keys=True,
+    )

     # We are getting an type error, because the get_config_from_file is not correctly typed,
     lighteval_config: LightEvalConfig = get_config_from_file(lighteval_config_path, config_class=LightEvalConfig)  # type: ignore
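Loading the checkpoint's config a second time as a plain nanotron Config (with model_config_class=None) gives the entry point a parsed model config alongside the lighteval config. A rough sketch of how the two are presumably bundled further down in main_nanotron.py, continuing the code from the hunk above; the field names come from the FullNanotronConfig change in this commit, but the local variable name is illustrative, not taken from the diff:

    # Sketch only: combine the lighteval config and the checkpoint's nanotron
    # config into the single object the pipeline consumes.
    nanotron_config = FullNanotronConfig(
        lighteval_config=lighteval_config,
        nanotron_config=model_config,
    )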

src/lighteval/models/nanotron/nanotron_model.py

Lines changed: 8 additions & 1 deletion
@@ -336,7 +336,14 @@ def tok_decode(self, tokens: torch.LongTensor) -> List[str]:
         return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)

     def _model_call(self, inputs: torch.Tensor) -> torch.Tensor:
-        return self.model(inputs)
+        position_ids = (
+            torch.arange(
+                inputs.shape[1], device=inputs.device, dtype=torch.int32
+            )
+            .unsqueeze(0)
+            .repeat(inputs.shape[0], 1)
+        )
+        return self.model(inputs, position_ids)

     def homogeneize_ending_conditions(self, ending_condition: tuple | dict | list | str) -> tuple[list, int]:
         """Ending conditions are submitted in several possible formats.

src/lighteval/pipeline.py

Lines changed: 2 additions & 2 deletions
@@ -186,12 +186,12 @@ def _init_parallelism_manager(self):
     def _init_model(self, model_config, model):
         logger.info("--- LOADING MODEL ---")
         if model_config is not None:
-            if self.parallel_context:
+            if self.parallel_context:
                 return NanotronLightevalModel(
                     checkpoint_path=os.path.dirname(self.pipeline_parameters.nanotron_checkpoint_path)
                     if self.pipeline_parameters.nanotron_checkpoint_path
                     else "",
-                    nanotron_config=self.model_config,
+                    nanotron_config=model_config,
                     parallel_context=self.parallel_context,
                     debug_one_layer_model=False,
                     model_class=None,
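The substantive fix here is passing the model_config argument received by _init_model rather than self.model_config, presumably because the pipeline does not (or no longer does) store that config on self. A tiny illustration of the pitfall, with hypothetical names that are not lighteval code:

class BackendModelSketch:            # hypothetical stand-in, not lighteval code
    def __init__(self, config):
        self.config = config

class PipelineSketch:
    def __init__(self, model_config):
        # The config is only passed through; it is never stored on self,
        # so self.model_config does not exist inside _init_model.
        self.model = self._init_model(model_config)

    def _init_model(self, model_config):
        # return BackendModelSketch(config=self.model_config)   # AttributeError
        return BackendModelSketch(config=model_config)          # use the argument

pipeline = PipelineSketch(model_config={"hidden_size": 16})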

src/lighteval/tasks/lighteval_task.py

Lines changed: 2 additions & 0 deletions
@@ -107,6 +107,7 @@ class LightevalTaskConfig:
     few_shots_select: Optional[str] = None

     # Generation args
+    output_regex: Optional[str] = None
     generation_size: Optional[int] = None
     generation_grammar: Optional[TextGenerationInputGrammarType] = None
     stop_sequence: Optional[ListLike[str]] = None
@@ -120,6 +121,7 @@ class LightevalTaskConfig:
     must_remove_duplicate_docs: bool = False

     version: int = 0
+    frozen: bool = False

     def __post_init__(self):
         # If we got a Metrics enums instead of a Metric, we convert
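Both added fields carry defaults, so existing LightevalTaskConfig(...) call sites keep constructing without changes: output_regex slots in with the other generation options and frozen defaults to False next to version. A simplified sketch of why default-valued dataclass fields are backward compatible (stand-in class, not the real config):

from dataclasses import dataclass
from typing import Optional

@dataclass
class TaskConfigSketch:              # simplified stand-in for LightevalTaskConfig
    name: str
    generation_size: Optional[int] = None
    version: int = 0
    # Newly added, both defaulted so older call sites still work:
    output_regex: Optional[str] = None
    frozen: bool = False

old_style = TaskConfigSketch(name="my_task", generation_size=256)              # still valid
new_style = TaskConfigSketch(name="my_task", output_regex=r"\d+", frozen=True)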
