Skip to content

Commit 2dddbc5

Browse files
committed
Add backwards compatibility for transformers v4.57
1 parent b5115a0 commit 2dddbc5

File tree

4 files changed

+9
-4
lines changed

4 files changed

+9
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ dependencies = [
3434
"deprecated",
3535
"numpy",
3636
"datasets",
37-
"transformers>=5.0.0",
37+
"transformers>=4.57.0",
3838
"ninja",
3939
"numba>=0.62.0",
4040
"rich",

research_scratch/fsdp1_dummy_script.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@ def wrap_model_with_fsdp1(model: torch.nn.Module) -> FSDP:
5353
# Get transformer layer class for auto-wrap
5454
auto_wrap_policy = None
5555
if hasattr(model, "_no_split_modules") and model._no_split_modules:
56-
layer_name = model._no_split_modules[0]
56+
layer_name = next(iter(model._no_split_modules))
5757
layer_cls = get_module_class_from_name(model, layer_name)
5858
if layer_cls:
5959
auto_wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={layer_cls})

research_scratch/fsdp1_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ def wrap_fsdp1(model: torch.nn.Module) -> torch.nn.Module:
3838
model.gradient_checkpointing_enable()
3939

4040
# Determine the block class to auto-wrap (first no-split module)
41-
block_name = model._no_split_modules[0]
41+
block_name = next(iter(model._no_split_modules))
4242
block_cls = get_module_class_from_name(model, block_name)
4343
if block_cls is None:
4444
raise ValueError(f"Could not find module class named {block_name}")

scripts/process_data.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,15 @@
11
from pathlib import Path
22

33
import numpy as np
4+
import transformers
45
import typer
56
from datasets import load_dataset
67
from transformers import AutoTokenizer
78

9+
# Transformers v5 renamed 'additional_special_tokens' to 'extra_special_tokens'
10+
_TRANSFORMERS_V5 = int(transformers.__version__.split(".")[0]) >= 5
11+
_SPECIAL_TOKENS_KEY = "extra_special_tokens" if _TRANSFORMERS_V5 else "additional_special_tokens"
12+
813
app = typer.Typer()
914

1015
# Import from main codebase instead of duplicating
@@ -235,7 +240,7 @@ def process_data(
235240
):
236241
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
237242
assistant_tk_ids, user_tk_ids = infer_special_token_sequences(tokenizer)
238-
tokenizer.add_special_tokens({"extra_special_tokens": [string_for_printing_masks]})
243+
tokenizer.add_special_tokens({_SPECIAL_TOKENS_KEY: [string_for_printing_masks]})
239244
string_for_printing_masks_tk = tokenizer.encode(string_for_printing_masks, add_special_tokens=False)[0]
240245

241246
dataset = load_dataset("json", data_files=input_jsonl, split="train")

0 commit comments

Comments (0)