Skip to content

Commit 3f05c67

Browse files
authored
deps: cap transformers at 4.40.2 (#218)
* deps: pin transformers below v4.41d Signed-off-by: Anh-Uong <[email protected]> * remove deprecated requirements.yaml Signed-off-by: Anh-Uong <[email protected]> * update unit tests with old evaluation_strategy flag Signed-off-by: Anh-Uong <[email protected]> * set transformers upper bound to 4.40.2 - update eval flag in docs Signed-off-by: Anh-Uong <[email protected]> --------- Signed-off-by: Anh-Uong <[email protected]>
1 parent 0949699 commit 3f05c67

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

examples/prompt_tuning_twitter_complaints/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ tuning/sft_trainer.py \
5151
--per_device_train_batch_size 1 \
5252
--per_device_eval_batch_size 1 \
5353
--gradient_accumulation_steps 1 \
54-
--eval_strategy "no" \
54+
--evaluation_strategy "no" \
5555
--save_strategy "epoch" \
5656
--learning_rate 1e-5 \
5757
--weight_decay 0. \

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ classifiers=[
2828
dependencies = [
2929
"numpy>=1.26.4,<2.0",
3030
"accelerate>=0.20.3,<0.40",
31-
"transformers>=4.34.1,<5.0,!=4.38.2",
31+
"transformers>=4.34.1,<=4.40.2,!=4.38.2",
3232
"torch>=2.2.0,<3.0",
3333
"sentencepiece>=0.1.99,<0.3",
3434
"tokenizers>=0.13.3,<1.0",

tests/test_sft_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def test_run_causallm_pt_with_validation():
304304
with tempfile.TemporaryDirectory() as tempdir:
305305
train_args = copy.deepcopy(TRAIN_ARGS)
306306
train_args.output_dir = tempdir
307-
train_args.eval_strategy = "epoch"
307+
train_args.evaluation_strategy = "epoch"
308308
data_args = copy.deepcopy(DATA_ARGS)
309309
data_args.validation_data_path = TWITTER_COMPLAINTS_DATA
310310

@@ -317,7 +317,7 @@ def test_run_causallm_pt_with_validation_data_formatting():
317317
with tempfile.TemporaryDirectory() as tempdir:
318318
train_args = copy.deepcopy(TRAIN_ARGS)
319319
train_args.output_dir = tempdir
320-
train_args.eval_strategy = "epoch"
320+
train_args.evaluation_strategy = "epoch"
321321
data_args = copy.deepcopy(DATA_ARGS)
322322
data_args.validation_data_path = TWITTER_COMPLAINTS_DATA
323323
data_args.dataset_text_field = None

0 commit comments

Comments
 (0)