[GKD] Buffer Implementation for Distillation Trainer #5137
Status: Open
cmpatino wants to merge 28 commits into huggingface:main from cmpatino:kd-buffering
+631 −267
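Per the PR title and commit history, this change adds a buffer that stores teacher generations so they can be reused across training steps, validates `num_generations`, and warns when incomplete batches are dropped. The real implementation is in the diff below; purely as an illustrative sketch of the idea (all class and method names here are hypothetical, not the PR's API):

```python
import logging
from collections import deque

logger = logging.getLogger(__name__)


class GenerationBuffer:
    """Hypothetical sketch: store generated completions and serve them back
    in fixed-size batches, dropping any incomplete tail with a warning
    (mirroring, by commit message only, the PR's behavior)."""

    def __init__(self, num_generations: int):
        # Validation analogous to the "Add validation for num_generations" commit.
        if num_generations < 1:
            raise ValueError(f"num_generations must be >= 1, got {num_generations}")
        self.num_generations = num_generations
        self._items = deque()

    def add(self, completions):
        """Append newly generated completions to the buffer."""
        self._items.extend(completions)

    def pop_batches(self, batch_size: int):
        """Drain the buffer into full batches; discard and warn about any
        leftover generations that do not fill a complete batch."""
        batches = []
        while len(self._items) >= batch_size:
            batches.append([self._items.popleft() for _ in range(batch_size)])
        if self._items:
            logger.warning(
                "Dropping %d buffered generations that do not fill a batch",
                len(self._items),
            )
            self._items.clear()
        return batches
```

This is only a conceptual outline; the PR's actual buffering is integrated into the trainer's data flow rather than a standalone class.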
Commits (28):
719c644 Implement buffer for GOLDTrainer (cmpatino)
904378b Clean up code from KD buffer (cmpatino)
6a2ece5 Test scripts for trial run (cmpatino)
ee07aec Apply fixes to the Liger loss setting (cmpatino)
a3fd2af Remove test scripts (cmpatino)
b0669d9 Handle config parameters better in gold script (cmpatino)
b0c4f3e Upload provisional SLURM script for GOLD (cmpatino)
602e564 Refine logic and comments (cmpatino)
c4f9a64 Improve clarity of buffer implementation (cmpatino)
111b85e Add validation for num_generations (cmpatino)
022af62 Add clarifying comment to num_generations (cmpatino)
33e0a82 Patch issue with ZeRO-3 (cmpatino)
dbb6e70 Refactor context for ZeRO-3 + Liger (cmpatino)
9da54b3 Simplify comments and code logic (cmpatino)
1cec9ea Merge pull request #1 from cmpatino/kd-buffer-fix (cmpatino)
4435409 Add scripts to run GOLD (cmpatino)
ce41aba Merge pull request #2 from cmpatino/kd-buffer-fix (cmpatino)
c0a857f Merge branch 'kd-buffering' of github.com:cmpatino/trl into kd-buffering (cmpatino)
fa62472 Merge branch 'main' into kd-buffering (cmpatino)
31161a0 Refactor to simplify logic (cmpatino)
da7ef50 Handle student versioning params (cmpatino)
e24e681 Add warning when dropping incomplete batches (cmpatino)
8d31b7a Add clarifying note in docs (cmpatino)
1ef205b Remove SLURM script used for testing (cmpatino)
506afc1 Remove reference to wandb (cmpatino)
7e9cb5e Merge branch 'main' into kd-buffering (lewtun)
98ec20c Remove `_RepeatEachBatchDataLoader` to simplify codebase (cmpatino)
f89e77f Merge branch 'kd-buffering' of github.com:cmpatino/trl into kd-buffering (cmpatino)
```diff
@@ -51,6 +51,7 @@
 """

 import logging
+import os

 from datasets import load_dataset
 from transformers import AutoTokenizer, GenerationConfig
```
```diff
@@ -78,6 +79,19 @@
 ################
 # Model & Tokenizer
 ################
+if training_args.student_model_revision is None:
+    training_args.student_model_revision = model_args.model_revision
+elif (
+    model_args.model_revision is not None
+    and training_args.student_model_revision != model_args.model_revision
+):
+    raise ValueError(
+        "Conflicting revisions for student model: "
+        f"student_model_revision={training_args.student_model_revision!r} and "
+        f"model_revision={model_args.model_revision!r}. "
+        "Set only one revision, or set both to the same value."
+    )

 quantization_config = get_quantization_config(model_args)
 model_kwargs = dict(
     revision=training_args.student_model_revision,
```

Review thread on the `student_model_revision` block:

Collaborator: Why do we need the …
Author (cmpatino): True. I'll support just …
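The precedence rule in the hunk above (fall back to `model_revision` when the student revision is unset, reject conflicting values) can be factored into a standalone function. A sketch of the same logic, with a function name of my own choosing rather than anything in the PR:

```python
def resolve_student_revision(student_model_revision, model_revision):
    """Resolve the student model revision:
    - if student_model_revision is unset, fall back to model_revision;
    - if both are set and differ, raise, matching the PR's error message."""
    if student_model_revision is None:
        return model_revision
    if model_revision is not None and student_model_revision != model_revision:
        raise ValueError(
            "Conflicting revisions for student model: "
            f"student_model_revision={student_model_revision!r} and "
            f"model_revision={model_revision!r}. "
            "Set only one revision, or set both to the same value."
        )
    return student_model_revision
```

Matching values are allowed through, so setting both to the same revision is harmless, as the error message suggests.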
```diff
@@ -93,21 +107,21 @@
 if training_args.teacher_tokenizer_name_or_path is None and training_args.use_uld_loss:
     training_args.teacher_tokenizer_name_or_path = training_args.teacher_model_name_or_path
 teacher_model_kwargs = dict(
     revision=model_args.model_revision,
     trust_remote_code=model_args.trust_remote_code,
     attn_implementation=model_args.attn_implementation,
     torch_dtype=model_args.dtype,
     use_cache=True,
     device_map=get_kbit_device_map() if quantization_config is not None else None,
     quantization_config=quantization_config,
 )
 if training_args.teacher_model_init_kwargs is not None:
     teacher_model_kwargs.update(training_args.teacher_model_init_kwargs)
 training_args.teacher_model_init_kwargs = teacher_model_kwargs

 tokenizer = AutoTokenizer.from_pretrained(
     model_args.model_name_or_path,
     revision=model_args.model_revision,
     trust_remote_code=model_args.trust_remote_code,
     padding_side="left",
 )
 if tokenizer.pad_token is None:
     tokenizer.pad_token = tokenizer.eos_token
```
```diff
@@ -120,7 +134,6 @@
 ################
 # Training
 ################
-# Handle eval dataset - check if test split exists, fallback to validation or None
 eval_dataset = None
 if training_args.eval_strategy != "no":
     if script_args.dataset_test_split in dataset:
```
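The last hunk's removed comment describes the eval-split resolution: prefer the configured test split, fall back to a validation split, else run without an eval set. A sketch of that logic as a standalone helper (the function name is mine, and the exact fallback order past the shown context is an assumption based on the removed comment):

```python
def select_eval_split(dataset, test_split, eval_strategy="steps"):
    """Pick an eval split: the configured test split if present,
    otherwise a 'validation' split, otherwise None. Returns None
    immediately when evaluation is disabled (eval_strategy == "no")."""
    if eval_strategy == "no":
        return None
    if test_split in dataset:
        return dataset[test_split]
    if "validation" in dataset:
        return dataset["validation"]
    return None
```

Here `dataset` is anything mapping split names to datasets (e.g. a `datasets.DatasetDict`), so a plain dict works for illustration.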
Collaborator: `os` is imported but not used in this script, which will fail linting / static checks in many setups. Remove the unused import, or use it where intended.