[update] Updated RoPE Configuration for HF Models (transformers) w. backward-compatible support for vLLM #690 #703
Changes from 35 commits
```diff
@@ -660,15 +660,43 @@ def build_dataloader(
     return dataloader


-def get_rope_scaling_config(trainer_cfg: DictConfig) -> dict[str, Any]:
-    if "rope_scaling" not in trainer_cfg:
-        return {}
-    if trainer_cfg.rope_scaling is None:
-        return None
-    return OmegaConf.to_container(trainer_cfg.rope_scaling)
+def get_rope_parameters_config(trainer_cfg: DictConfig) -> dict[str, Any]:
+    rope_scaling = trainer_cfg.get("rope_scaling", None)
+    rope_theta = trainer_cfg.get("rope_theta", None)
+    has_old_config = rope_scaling is not None or rope_theta is not None
+
+    rope_parameters_new = trainer_cfg.get("rope_parameters", None)
+    has_new_config = rope_parameters_new is not None
+
+    if has_old_config and has_new_config:
+        logger.warning(
+            "Both old ('rope_scaling', 'rope_theta') and new ('rope_parameters') RoPE configs are provided. "
+            "Prioritizing the old config for backward compatibility. Please migrate to 'rope_parameters'."
+        )
+
+    if has_old_config:
+        rope_parameters = {}
+        if rope_scaling is not None:
+            rope_scaling_dict = (
+                OmegaConf.to_container(rope_scaling, resolve=True)
+                if isinstance(rope_scaling, DictConfig)
+                else rope_scaling
+            )
```
Comment on lines +680 to +684 (Contributor): This conditional expression to create `rope_scaling_dict = OmegaConf.to_container(rope_scaling, resolve=True)` …
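A minimal sketch of the guard under discussion (the helper name and values are my own, not from the PR). To the best of my knowledge, `OmegaConf.to_container` only accepts OmegaConf containers and raises a `ValueError` on a plain `dict`, which is why the conditional passes plain dicts through unchanged:

```python
from omegaconf import DictConfig, OmegaConf

def as_plain_dict(rope_scaling):
    # Hypothetical helper mirroring the PR's conditional expression:
    # convert only when we actually hold an OmegaConf container,
    # otherwise pass the value through as-is.
    if isinstance(rope_scaling, DictConfig):
        return OmegaConf.to_container(rope_scaling, resolve=True)
    return rope_scaling

print(as_plain_dict(OmegaConf.create({"rope_type": "yarn", "factor": 4.0})))  # converted to a plain dict
print(as_plain_dict({"rope_type": "yarn", "factor": 4.0}))                    # passed through unchanged
```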
```diff
+            if isinstance(rope_scaling_dict, dict):
+                rope_parameters.update(rope_scaling_dict)
+            else:
+                logger.warning(f"Ignoring 'rope_scaling' as it is not a dictionary. Found: {rope_scaling_dict}")
+        if rope_theta is not None:
+            rope_parameters["rope_theta"] = rope_theta
+        return rope_parameters
+
+    elif has_new_config:
+        new_params = OmegaConf.to_container(rope_parameters_new, resolve=True)
+        if isinstance(new_params, dict):
+            return new_params
+        if new_params is not None:
+            logger.warning(f"Ignoring 'rope_parameters' as it is not a dictionary. Found: {new_params}")
+        return {}
```
Comment on lines +697 to +699 (Contributor): The

```python
logger.warning(f"Ignoring 'rope_parameters' as it is not a dictionary. Found: {new_params}")
return {}
```

…
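To make the fallback in this thread concrete, a small sketch (the malformed config value is an assumption, not from the PR) of how a non-mapping `rope_parameters` entry is detected; in the PR this case logs the warning above and returns an empty dict instead of raising:

```python
from omegaconf import OmegaConf

# e.g. a YAML typo that turns 'rope_parameters' into a list instead of a mapping
bad_cfg = OmegaConf.create({"rope_parameters": ["yarn", 4.0]})

new_params = OmegaConf.to_container(bad_cfg.rope_parameters, resolve=True)
print(isinstance(new_params, dict))  # False -> the PR logs a warning and falls back to {}
```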
```diff
-def get_rope_theta_config(trainer_cfg: DictConfig) -> int | None:
-    if "rope_theta" not in trainer_cfg:
-        return None
-    return trainer_cfg.rope_theta
+
+    else:
+        return {}
```
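For reference, a rough sketch of how the two config styles are meant to converge under `get_rope_parameters_config`. The keys and values below (YaRN scaling, a `rope_theta` of 1000000) are illustrative assumptions, and the import path is hypothetical since the module is not shown in this hunk:

```python
from omegaconf import OmegaConf

# Old style, kept for backward compatibility (e.g. existing vLLM-oriented configs):
old_style = OmegaConf.create(
    {"rope_scaling": {"rope_type": "yarn", "factor": 4.0}, "rope_theta": 1000000}
)

# New style, grouping everything under a single 'rope_parameters' mapping:
new_style = OmegaConf.create(
    {"rope_parameters": {"rope_type": "yarn", "factor": 4.0, "rope_theta": 1000000}}
)

# Hypothetical import path -- adjust to wherever get_rope_parameters_config is defined:
# from <trainer_module> import get_rope_parameters_config
#
# Both calls are expected to resolve to the same dict:
#   {"rope_type": "yarn", "factor": 4.0, "rope_theta": 1000000}
# get_rope_parameters_config(old_style)
# get_rope_parameters_config(new_style)
```

If both styles appear in the same config, the function warns and keeps the old keys, per the warning message in the diff.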