
Commit 70f7f9e

removed EnvConfig for nanotron
1 parent fac137e commit 70f7f9e

File tree: 3 files changed, +12 / -26 lines

src/lighteval/main_nanotron.py

Lines changed: 3 additions & 11 deletions
@@ -26,9 +26,6 @@
 from typer import Option
 from typing_extensions import Annotated
 
-
-CACHE_DIR: str = os.getenv("HF_HOME", "/scratch")
-
 HELP_PANEL_NAME_1 = "Common Parameters"
 HELP_PANEL_NAME_2 = "Logging Parameters"
 HELP_PANEL_NAME_3 = "Debug Parameters"
@@ -42,8 +39,7 @@ def nanotron(
     checkpoint_config_path: Annotated[
         str, Option(help="Path to the nanotron checkpoint YAML or python config file, potentially on s3.")
     ],
-    lighteval_config_path: Annotated[str, Option(help="Path to a YAML config to be used for the evaluation.")],
-    cache_dir: Annotated[str, Option(help="Cache directory for datasets and models.")] = CACHE_DIR,
+    lighteval_config_path: Annotated[str, Option(help="Path to a YAML config to be used for the evaluation.")]
 ):
     """
     Evaluate models using nanotron as backend.
@@ -54,9 +50,6 @@ def nanotron(
     from lighteval.logging.evaluation_tracker import EvaluationTracker
     from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
     from lighteval.utils.imports import NO_NANOTRON_ERROR_MSG, is_nanotron_available
-    from lighteval.utils.utils import EnvConfig
-
-    env_config = EnvConfig(token=os.getenv("HF_TOKEN"), cache_dir=cache_dir)
 
     if not is_nanotron_available():
         raise ImportError(NO_NANOTRON_ERROR_MSG)
@@ -75,7 +68,7 @@ def nanotron(
     # We are getting an type error, because the get_config_from_file is not correctly typed,
     lighteval_config: LightEvalConfig = get_config_from_file(lighteval_config_path, config_class=LightEvalConfig)  # type: ignore
     nanotron_config = FullNanotronConfig(lighteval_config, model_config)
-
+
     evaluation_tracker = EvaluationTracker(
         output_dir=lighteval_config.logging.output_dir,
         hub_results_org=lighteval_config.logging.results_org,
@@ -89,12 +82,11 @@ def nanotron(
 
     pipeline_parameters = PipelineParameters(
         launcher_type=ParallelismManager.NANOTRON,
-        env_config=env_config,
         job_id=os.environ.get("SLURM_JOB_ID", 0),
         nanotron_checkpoint_path=checkpoint_config_path,
         dataset_loading_processes=lighteval_config.tasks.dataset_loading_processes,
         custom_tasks_directory=lighteval_config.tasks.custom_tasks,
-        override_batch_size=lighteval_config.batch_size,
+        # override_batch_size=lighteval_config.batch_size,
         num_fewshot_seeds=1,
         max_samples=lighteval_config.tasks.max_samples,
         use_chat_template=False,
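
Editor's note: with `EnvConfig` and the `--cache-dir` option gone, the Hub token and cache location for this entry point are expected to come from the standard Hugging Face environment variables rather than from CLI plumbing. A minimal sketch of that assumption follows; the variable names are the usual huggingface_hub conventions and the paths/values are placeholders, not anything this commit defines.

import os

# Assumption: caching and authentication now fall back to the standard
# environment variables read by huggingface_hub / datasets / transformers.
os.environ.setdefault("HF_HOME", "/scratch/hf_cache")  # hypothetical cache root
os.environ.setdefault("HF_TOKEN", "hf_xxx")            # hypothetical token placeholder

# The command is then presumably launched without a cache flag, e.g.:
#   lighteval nanotron --checkpoint-config-path ckpt.yaml --lighteval-config-path eval.yaml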

src/lighteval/models/nanotron/nanotron_model.py

Lines changed: 8 additions & 14 deletions
@@ -56,7 +56,7 @@
 )
 from lighteval.utils.imports import is_nanotron_available
 from lighteval.utils.parallelism import find_executable_batch_size
-from lighteval.utils.utils import EnvConfig, as_list
+from lighteval.utils.utils import as_list
 
 
 logger = logging.getLogger(__name__)
@@ -101,7 +101,6 @@ def __init__(
         trust_remote_code: bool = False,
         debug_one_layer_model: bool = False,
         model_class: Optional[Type] = None,
-        env_config: EnvConfig = None,
     ):
         """Initializes a nanotron model for evaluation.
         Args:
@@ -138,7 +137,6 @@ def __init__(
         self._add_special_tokens = add_special_tokens
         self._tokenizer = self._create_auto_tokenizer(
             pretrained=tokenizer.tokenizer_name_or_path,
-            env_config=env_config,
             trust_remote_code=trust_remote_code,
         )
         self._tokenizer.model_max_length = self.max_length
@@ -230,23 +228,18 @@ def _create_auto_tokenizer(
         *,
         pretrained: str,
         tokenizer: Optional[str] = None,
-        env_config: EnvConfig = None,
         trust_remote_code: bool = False,
     ) -> transformers.PreTrainedTokenizer:
         """Returns a pre-trained tokenizer from a pre-trained tokenizer configuration."""
 
         try:
             tokenizer = AutoTokenizer.from_pretrained(
                 pretrained if tokenizer is None else tokenizer,
-                cache_dir=env_config.cache_dir,
-                token=env_config.token,
                 trust_remote_code=trust_remote_code,
             )
         except RecursionError:
             tokenizer = AutoTokenizer.from_pretrained(
                 pretrained if tokenizer is None else tokenizer,
-                cache_dir=env_config.cache_dir,
-                token=env_config.token,
                 unk_token="<unk>",
                 trust_remote_code=trust_remote_code,
             )
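
Editor's note: dropping `cache_dir=` and `token=` from `AutoTokenizer.from_pretrained` leans on the library defaults: transformers/huggingface_hub resolve the cache from `HF_HOME` and pick up a token from `HF_TOKEN` or a previous `huggingface-cli login`. A minimal sketch of the resulting call, with a placeholder repo name:

from transformers import AutoTokenizer

# Assumption: cache location and auth come from the environment
# (HF_HOME / HF_TOKEN), not from an explicit EnvConfig object.
tokenizer = AutoTokenizer.from_pretrained(
    "my-org/my-tokenizer",        # hypothetical tokenizer repo
    trust_remote_code=False,
)
print(tokenizer.__class__.__name__)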
@@ -711,14 +704,14 @@ def _loglikelihood_single_token(
                 inputs, padding_length=max_context, max_context=max_context, full_attention_masks=True
             )
             # batched_inputs, batch_attention, input_lengths, truncated, padded
-
-            out = self.model(input_ids=batch_model.input_ids, input_mask=batch_model.input_mask)
+            position_ids = torch.arange(batch_model.input_ids.shape[1], device=self.device, dtype=torch.int32).unsqueeze(0).repeat(batch_model.input_ids.shape[0], 1)
+            out = self.model(input_ids=batch_model.input_ids, position_ids=position_ids)
 
             if dist.get_rank(self.parallel_context.pp_pg) == self.output_pp_rank:
                 # This process got outputs
 
-                # Gather all the output across TP
-                out = out.transpose(0, 1).contiguous()  # [batch, seq_length, vocab]
+                # Gather all the output accross TP
+                out = out.view(*batch_model.input_ids.shape, -1).contiguous()  # [batch, seq_length, vocab]
 
                 gathered_out = [torch.zeros_like(out) for _ in range(self.parallel_context.tp_pg.size())]
                 dist.all_gather(gathered_out, out, group=self.parallel_context.tp_pg, async_op=False)
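
Editor's note: the forward pass now receives explicit `position_ids` instead of `input_mask`. A standalone sketch of the construction used above, with made-up sizes, just to show the resulting shape; note it assigns positions 0..seq_len-1 to every row regardless of padding, so it illustrates the shape only, not padding semantics.

import torch

batch_size, seq_len = 2, 8                                   # made-up sizes
input_ids = torch.randint(0, 1000, (batch_size, seq_len))

# Same construction as in the diff: one row [0, 1, ..., seq_len - 1]
# repeated for every sequence in the batch.
position_ids = (
    torch.arange(seq_len, dtype=torch.int32)
    .unsqueeze(0)
    .repeat(batch_size, 1)
)

assert position_ids.shape == input_ids.shape                 # (batch_size, seq_len)

An `.expand(batch_size, -1)` would avoid materializing the copy, but the sketch mirrors the `.repeat` call from the diff.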
@@ -944,7 +937,8 @@ def _loglikelihood_tokens(
             )
             # batched_inputs, batch_attention, input_lengths, truncated, padded
             with torch.no_grad():
-                out = self.model(input_ids=batch_model.input_ids, input_mask=batch_model.input_mask)
+                position_ids = torch.arange(batch_model.input_ids.shape[1], device=self.device, dtype=torch.int32).unsqueeze(0).repeat(batch_model.input_ids.shape[0], 1)
+                out = self.model(input_ids=batch_model.input_ids, position_ids=position_ids)
 
             if dist.get_rank(self.parallel_context.pp_pg) == self.output_pp_rank:
                 # This process got outputs
@@ -954,7 +948,7 @@ def _loglikelihood_tokens(
                 dist.all_gather(gathered_out, out, group=self.parallel_context.tp_pg, async_op=False)
                 out = torch.cat(gathered_out, dim=-1)
 
-                out = out.transpose(0, 1)  # [batch, seq_length, vocab]
+                out = out.view(*batch_model.input_ids.shape, -1)  # [batch, seq_length, vocab]
                 multi_logits = F.log_softmax(out, dim=-1)  # [batch, padding_length, vocab]
 
                 logits_sum = []
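
Editor's note: replacing `out.transpose(0, 1)` with `out.view(*batch_model.input_ids.shape, -1)` suggests the model output is no longer laid out as [seq_length, batch, vocab] but can be reshaped directly into [batch, seq_length, vocab]. A shape-only sketch of the reshape under that assumption, with made-up sizes:

import torch

batch_size, seq_len, vocab = 2, 8, 50                        # made-up sizes
input_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)

# Assumption: the forward pass returns logits already ordered batch-major,
# e.g. flattened to [batch * seq_len, vocab].
flat_logits = torch.randn(batch_size * seq_len, vocab)

logits = flat_logits.view(*input_ids.shape, -1)              # -> [batch, seq_len, vocab]
assert logits.shape == (batch_size, seq_len, vocab)

Unlike `transpose`, `view` never reorders elements; the two lines are only equivalent if the producing model's memory layout matches, which is why this is flagged as an assumption rather than a property claimed by the commit.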

src/lighteval/pipeline.py

Lines changed: 1 addition & 1 deletion
@@ -155,7 +155,7 @@ def __init__(
         self.accelerator, self.parallel_context = self._init_parallelism_manager()
         self.model = self._init_model(model_config, model)
 
-        generation_parameters = model_config.generation_parameters.model_dump() if model_config else {}
+        generation_parameters = model_config.generation_parameters.model_dump() if model_config and hasattr(model_config, "generation_parameters") else {}
 
         self.evaluation_tracker.general_config_logger.log_model_info(generation_parameters, self.model.model_info)
         self._init_random_seeds()
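
Editor's note: the added `hasattr` check guards against model configs that simply have no `generation_parameters` attribute (presumably the nanotron config passed by this backend), not just against `model_config` being `None`. A small self-contained sketch of the pattern with hypothetical config classes:

from dataclasses import dataclass, field


@dataclass
class _GenerationParameters:              # hypothetical stand-in for the real class
    temperature: float = 1.0

    def model_dump(self) -> dict:         # mimics the pydantic-style API used above
        return {"temperature": self.temperature}


@dataclass
class _TransformersLikeConfig:            # has generation parameters
    generation_parameters: _GenerationParameters = field(default_factory=_GenerationParameters)


@dataclass
class _NanotronLikeConfig:                # no generation_parameters attribute
    checkpoint_path: str = "s3://bucket/ckpt"


for cfg in (None, _TransformersLikeConfig(), _NanotronLikeConfig()):
    # Same guard as in the diff: only dump when the attribute exists.
    params = cfg.generation_parameters.model_dump() if cfg and hasattr(cfg, "generation_parameters") else {}
    print(type(cfg).__name__ if cfg else "None", params)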
