File tree Expand file tree Collapse file tree 4 files changed +9
-4
lines changed
Expand file tree Collapse file tree 4 files changed +9
-4
lines changed Original file line number Diff line number Diff line change @@ -34,7 +34,7 @@ dependencies = [
3434 " deprecated" ,
3535 " numpy" ,
3636 " datasets" ,
37- " transformers>=5.0.0" ,
37+ " transformers>=4.57.0" ,
3838 " ninja" ,
3939 " numba>=0.62.0" ,
4040 " rich" ,
Original file line number Diff line number Diff line change @@ -53,7 +53,7 @@ def wrap_model_with_fsdp1(model: torch.nn.Module) -> FSDP:
5353 # Get transformer layer class for auto-wrap
5454 auto_wrap_policy = None
5555 if hasattr (model , "_no_split_modules" ) and model ._no_split_modules :
56- layer_name = model ._no_split_modules [ 0 ]
56+ layer_name = next ( iter ( model ._no_split_modules ))
5757 layer_cls = get_module_class_from_name (model , layer_name )
5858 if layer_cls :
5959 auto_wrap_policy = functools .partial (transformer_auto_wrap_policy , transformer_layer_cls = {layer_cls })
Original file line number Diff line number Diff line change @@ -38,7 +38,7 @@ def wrap_fsdp1(model: torch.nn.Module) -> torch.nn.Module:
3838 model .gradient_checkpointing_enable ()
3939
4040 # Determine the block class to auto-wrap (first no-split module)
41- block_name = model ._no_split_modules [ 0 ]
41+ block_name = next ( iter ( model ._no_split_modules ))
4242 block_cls = get_module_class_from_name (model , block_name )
4343 if block_cls is None :
4444 raise ValueError (f"Could not find module class named { block_name } " )
Original file line number Diff line number Diff line change 11from pathlib import Path
22
33import numpy as np
4+ import transformers
45import typer
56from datasets import load_dataset
67from transformers import AutoTokenizer
78
9+ # Transformers v5 renamed 'additional_special_tokens' to 'extra_special_tokens'
10+ _TRANSFORMERS_V5 = int (transformers .__version__ .split ("." )[0 ]) >= 5
11+ _SPECIAL_TOKENS_KEY = "extra_special_tokens" if _TRANSFORMERS_V5 else "additional_special_tokens"
12+
813app = typer .Typer ()
914
1015# Import from main codebase instead of duplicating
@@ -235,7 +240,7 @@ def process_data(
235240):
236241 tokenizer = AutoTokenizer .from_pretrained (model_name_or_path )
237242 assistant_tk_ids , user_tk_ids = infer_special_token_sequences (tokenizer )
238- tokenizer .add_special_tokens ({"extra_special_tokens" : [string_for_printing_masks ]})
243+ tokenizer .add_special_tokens ({_SPECIAL_TOKENS_KEY : [string_for_printing_masks ]})
239244 string_for_printing_masks_tk = tokenizer .encode (string_for_printing_masks , add_special_tokens = False )[0 ]
240245
241246 dataset = load_dataset ("json" , data_files = input_jsonl , split = "train" )
You can’t perform that action at this time.
0 commit comments