Skip to content

Commit 7205eb2

Browse files
authored
Fix: interface change for sampling weights (#156)
* fix: refactor sampling weight — Signed-off-by: Mehant Kammakomati <[email protected]>
* fix: refactor sampling weight — Signed-off-by: Mehant Kammakomati <[email protected]>
* fix: refactor sampling weight — Signed-off-by: Mehant Kammakomati <[email protected]>

---------

Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent 30c2c55 commit 7205eb2

File tree

2 files changed

+48
-28
lines changed

2 files changed

+48
-28
lines changed

plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Standard
22
from logging import getLogger
3-
from typing import List, Optional
3+
from typing import Optional
44
import json
55
import math
66
import os
@@ -27,7 +27,7 @@ def __init__(
2727
collators_dict: dict,
2828
eval_dataset_dict: DatasetDict,
2929
eval_collators_dict: dict,
30-
sampling_weights: Optional[List[float]] = None,
30+
sampling_weights: Optional[dict] = None,
3131
gamma: float = 0.1,
3232
eta: float = 0.3,
3333
sampling_interval: int = 1,
@@ -51,7 +51,7 @@ def __init__(
5151
eval datasets.
5252
eval_collators_dict (dict): collator corresponding to each dataset
5353
used while constructing torch dataloader.
54-
sampling_weights (Optional[List[float]], optional): Initial
54+
sampling_weights (Optional[dict], optional): Initial
5555
set of sampling weights to start with. Defaults to equal weightage.
5656
gamma (float, optional): MAB hyperparameter. Defaults to 0.1.
5757
eta (float, optional): MAB hyperparameter. Defaults to 0.3.
@@ -123,9 +123,13 @@ def __init__(
123123
# are equally important. Weights based on the size of the datasets
124124
# and other such heuristics should be computed outside and passed
125125
# through sampling_weights while initializing this class.
126-
if sampling_weights is None:
127-
sampling_weights = [1] * self.total_categories
128-
self.sampling_weights = torch.tensor(sampling_weights, dtype=torch.float64)
126+
if not sampling_weights:
127+
self.sampling_weights = [1] * self.total_categories
128+
else:
129+
self.sampling_weights = []
130+
for cat in self.category_list:
131+
self.sampling_weights.append(sampling_weights[cat])
132+
self.sampling_weights = torch.tensor(self.sampling_weights, dtype=torch.float64)
129133
self.sampling_ratio = []
130134
self._update_sampling_ratio(self.sampling_weights)
131135

plugins/online-data-mixing/tests/test_online_data.py

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,43 @@
1313
# limitations under the License.
1414

1515
# Third Party
16+
from torch.utils.data import IterableDataset
17+
1618
# pylint: disable=import-error
1719
import pytest
1820
import torch
1921

2022
# First Party
2123
from fms_acceleration_odm import OnlineMixingDataset, Reward
2224

25+
26+
class SampleDataset(IterableDataset):
27+
def __init__(self, seq_length, vocab_size):
28+
self.seq_length = seq_length
29+
self.vocab_size = vocab_size
30+
31+
def __len__(self):
32+
pass
33+
34+
def __iter__(self):
35+
return self
36+
37+
def __next__(self):
38+
input_ids = torch.rand(self.seq_length)
39+
return {
40+
"input_ids": input_ids,
41+
"attention_mask": torch.ones(self.seq_length),
42+
"labels": input_ids,
43+
}
44+
45+
46+
def get_dataset(seq_len, vocab_size):
47+
return SampleDataset(seq_length=seq_len, vocab_size=vocab_size)
48+
49+
2350
PARAMETERS = [
2451
(
25-
[1, 100, 2],
52+
{"data_1": 1, "data_2": 100, "data_3": 2},
2653
[[1, 100, 1], [1, 200, 1], [1, 100, 1], [1, 1, 1000], [1, 1, 2000]],
2754
5,
2855
[1, 1, 1, 2, 2],
@@ -41,29 +68,18 @@ def test_online_data_mix_learning(
4168
batch_size = 100
4269
seq_length = 6
4370
vocab_size = 50
44-
input_ids = (
45-
torch.arange(batch_size * seq_length).reshape(batch_size, seq_length)
46-
% vocab_size
47-
)
48-
attention_mask = torch.tensor(
49-
[[1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1]]
50-
)
51-
labels = input_ids
52-
train_data = {
53-
"input_ids": input_ids,
54-
"labels": labels,
55-
"attention_mask": attention_mask,
56-
}
57-
eval_data = {
58-
"input_ids": input_ids,
59-
"labels": labels,
60-
"attention_mask": attention_mask,
71+
72+
train_data_dict = {
73+
"data_1": get_dataset(seq_len=seq_length, vocab_size=vocab_size),
74+
"data_2": get_dataset(seq_len=seq_length, vocab_size=vocab_size),
75+
"data_3": get_dataset(seq_len=seq_length, vocab_size=vocab_size),
6176
}
77+
collators_dict = {"data_1": None, "data_2": None, "data_3": None}
6278
dataset = OnlineMixingDataset(
63-
train_data,
64-
None,
65-
eval_data,
66-
None,
79+
train_data_dict,
80+
collators_dict,
81+
train_data_dict,
82+
collators_dict,
6783
sampling_weights,
6884
0.1,
6985
0.3,

Comments (0)