Commit 890725c

Updated ODM defaults (#648)
Signed-off-by: romitjain <[email protected]>
1 parent 5eeec0e commit 890725c

4 files changed: +14 -9 lines changed

docs/advanced-data-preprocessing.md

Lines changed: 7 additions & 5 deletions
@@ -147,11 +147,13 @@ Each data handler has:
 `odm` config has the following fields and is required when `datapreprocessor` `type` is `odm`.
 
 `odm`:
-`update_interval` (optional, int, defaults to `1`): Multi-Armed Bandit (MAB) is used to learn from the training signals and then provide mixing probabilities across datasets. `update_interval` defines the frequency of updating the MAB with training signals in terms of step count.
-`sampling_interval` (optional, int, defaults to `1`): Defines the frequency of choosing a dataset to sample from through MAB. The value is provided in terms of sample count.
-`reward_type` (optional, str, defaults to `entropy`): Type of reward to be used to update MAB. Currently supported rewards are `train_loss`, `validation_loss`, `entropy`, `entropy3_varent1`, `entropy_last_token`, `gradnorm`. More details can be found [here](https://github.com/foundation-model-stack/fms-acceleration/tree/main/plugins/online-data-mixing#rewards).
-`gamma` (optional, int, defaults to `0.1`): MAB hyper-parameter which is similar to exploration factor.
-`eta` (optional, int, defaults to `0.1`): MAB hyper-parameter which is similar to learning rate.
+- `update_interval` (optional, int, defaults to `None`): Multi-Armed Bandit (MAB) is used to learn from the training signals and then provide mixing probabilities across datasets. `update_interval` defines the frequency of updating the MAB with training signals in terms of step count. If not provided, it defaults to `eval_steps`
+- `sampling_interval` (optional, int, defaults to `1`): Defines the frequency of choosing a dataset to sample from through MAB. The value is provided in terms of sample count.
+- `reward_type` (optional, str, defaults to `entropy`): Type of reward to be used to update MAB. Currently supported rewards are `train_loss`, `validation_loss`, `entropy`, `entropy3_varent1`, `entropy_last_token`, `gradnorm`. More details can be found [here](https://github.com/foundation-model-stack/fms-acceleration/tree/main/plugins/online-data-mixing#rewards).
+- `gamma` (optional, int, defaults to `0.1`): MAB hyper-parameter which is similar to exploration factor.
+- `eta` (optional, int, defaults to `0.3`): MAB hyper-parameter which is similar to learning rate.
+- `auto_categorize_input_column` (optional, str, defaults to `None`): If only a single dataset is provided, this field is required to determin the column name which should be used to categorize the data into psuedo categories
+- `auto_categorize_num_categories` (optional, int, defaults to `None`): Used in conjunction with the above field, this field specifies the number of psuedo categories to be assigned in the dataset
 
 `datasets` (list):
 - `name` (optional, str): A unique identifier for the dataset.
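
To make the documented fields concrete, here is a minimal sketch of an `odm` block written as a Python dict. The top-level key layout, the dataset name, and the helper `single_dataset_needs_column` are assumptions for illustration and are not taken from the library; only the field names and default values come from the documentation hunk above.

```python
# Hypothetical config fragment mirroring the documented `odm` fields.
# The key layout and the dataset name are illustrative placeholders.
data_config = {
    "dataprocessor": {
        "type": "odm",
        "odm": {
            # `update_interval` is omitted, so it would fall back to `eval_steps`.
            "sampling_interval": 1,
            "reward_type": "entropy",
            "gamma": 0.1,
            "eta": 0.3,
        },
    },
    "datasets": [{"name": "dataset_a"}],
}


def single_dataset_needs_column(config: dict) -> bool:
    """Return True when exactly one dataset is configured and
    `auto_categorize_input_column` is unset. Illustrative helper only."""
    odm = config["dataprocessor"]["odm"]
    has_column = bool(odm.get("auto_categorize_input_column"))
    return len(config["datasets"]) == 1 and not has_column


needs_column = single_dataset_needs_column(data_config)
print(needs_column)
```

With a single dataset and no `auto_categorize_input_column`, the helper flags the config as incomplete, which matches the documented requirement for that field.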

tuning/config/acceleration_configs/odm.py

Lines changed: 3 additions & 3 deletions
@@ -24,10 +24,10 @@
 @dataclass
 class ODM:
     update_interval: int = None
-    sampling_interval: int = None
-    reward_type: str = None
+    sampling_interval: int = 1
+    reward_type: str = "entropy"
     gamma: float = 0.1
-    eta: float = 0.1
+    eta: float = 0.3
     resume_from_checkpoint: Union[bool, str] = False
     auto_categorize_input_column: str = None
     auto_categorize_num_categories: Optional[int] = None
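
The effect of the new defaults can be checked with a standalone copy of the dataclass. This is a sketch: the field names and defaults are copied from the hunk above, but the class is redefined here only so the example runs on its own, and the `reward_type` override is an arbitrary illustration.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class ODM:
    # Standalone copy of the post-commit fields, for demonstration only.
    update_interval: int = None
    sampling_interval: int = 1
    reward_type: str = "entropy"
    gamma: float = 0.1
    eta: float = 0.3
    resume_from_checkpoint: Union[bool, str] = False
    auto_categorize_input_column: str = None
    auto_categorize_num_categories: Optional[int] = None


# A user config that sets a single field keeps the defaults for the rest.
odm = ODM(**{"reward_type": "train_loss"})
print(odm)
```

Before this commit, an omitted `sampling_interval` or `reward_type` would have been `None` on the dataclass, even though the documentation listed `1` and `entropy` as defaults.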

tuning/data/setup_dataprocessor.py

Lines changed: 1 addition & 1 deletion
@@ -549,7 +549,7 @@ def setup_train_dataset_for_odm(
     )
 
     auto_categorize_config = {}
-    if hasattr(odm_config.odm, "auto_categorize_input_column"):
+    if odm_config.odm.auto_categorize_input_column:
         auto_categorize_config = {
             "input_column": "input_ids",
             "num_categories": int(odm_config.odm.auto_categorize_num_categories),
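
A short sketch of why the `hasattr` check was replaced: on a dataclass, every declared field exists as an attribute even when it holds its default of `None`, so `hasattr` is always true and the branch would run for configs that never asked for auto-categorization. The trimmed `ODM` class below is a stand-in defined only for this example.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ODM:
    # Trimmed stand-in with only the two fields relevant to this check.
    auto_categorize_input_column: str = None
    auto_categorize_num_categories: Optional[int] = None


odm = ODM()  # the user did not configure auto-categorization

# Old check: true for any instance, because the field is always declared.
old_check = hasattr(odm, "auto_categorize_input_column")

# New check: true only when the field holds a non-empty value.
new_check = bool(odm.auto_categorize_input_column)

# With the old check the branch would go on to call int(None), which fails.
try:
    int(odm.auto_categorize_num_categories)
    conversion_error = None
except TypeError as exc:
    conversion_error = exc

print(old_check, new_check, type(conversion_error).__name__)
```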

tuning/sft_trainer.py

Lines changed: 3 additions & 0 deletions
@@ -160,6 +160,9 @@ def train(
             "resume_from_checkpoint"
         ] = resume_from_checkpoint
         odm_config = ODMConfig(odm=ODM(**_dataconfig.dataprocessor.odm))
+        odm_config.odm.update_interval = (
+            odm_config.odm.update_interval or train_args.eval_steps
+        )
 
     # Validate parameters
     if (not isinstance(model_args.model_name_or_path, str)) or (
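
The fallback added in this hunk relies on Python's `or`, which returns its right operand whenever the left one is falsy. The sketch below isolates that expression in a hypothetical helper, `resolve_update_interval`, which does not exist in the repository; the numeric values are arbitrary.

```python
def resolve_update_interval(update_interval, eval_steps):
    """Mirror the `update_interval or eval_steps` expression from the hunk.

    Hypothetical helper, written only to show how the fallback behaves.
    """
    return update_interval or eval_steps


unset = resolve_update_interval(None, 100)      # falls back to eval_steps
explicit = resolve_update_interval(50, 100)     # keeps the configured value
zero = resolve_update_interval(0, 100)          # 0 is falsy, so it also falls back
both_unset = resolve_update_interval(None, None)

print(unset, explicit, zero, both_unset)
```

Two properties follow from using `or` here: an explicit `update_interval` of `0` is treated the same as an unset one, and if `eval_steps` is also unset the result stays `None`.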
