Commit a0f74da

fix: adding hints to the arguments and returns
Signed-off-by: Omobayode Fagbohungbe <[email protected]>
1 parent cfb9ea9 commit a0f74da

File tree

2 files changed: +71 -45 lines changed

fms_mo/utils/calib_data.py

Lines changed: 67 additions & 41 deletions
@@ -31,13 +31,15 @@
 import torch
 
 
-def return_tokenized_samples(nsamples, trainenc, seqlen, sequential=False) -> dict[str, torch.int]:
+def return_tokenized_samples(
+    nsamples: int, trainenc: list, seqlen: int, sequential: bool = False
+) -> dict:
     """Randomly crop nsamples sequence from trainenc, each with the length of seqlen.
     see below functions, e.g. get_wikitext2() for more details.
     """
     traindataset = {
-        "input_ids": torch.zeros(size = (nsamples, seqlen), dtype = torch.int),
-        "attention_mask": torch.zeros(size = (nsamples, seqlen), dtype = torch.int)
+        "input_ids": torch.zeros(size=(nsamples, seqlen), dtype=torch.int),
+        "attention_mask": torch.zeros(size=(nsamples, seqlen), dtype=torch.int),
     }
     i = 0
 

@@ -57,8 +59,13 @@ def return_tokenized_samples(nsamples, trainenc, seqlen, sequential=False) -> dict[str, torch.int]:
 
 
 def get_wikitext2(
-    nsamples, seed, seqlen, tokenizer, sequential=False, gptq_style=False
-):
+    nsamples: int,
+    seed: int,
+    seqlen: int,
+    tokenizer: str,
+    sequential: bool = False,
+    gptq_style: bool = False,
+) -> tuple[dict, dict]:
     """Prepare data for GPTQ using wikitext2 dataset.
 
     Args:
@@ -87,14 +94,21 @@
         nsamples, trainenc, seqlen, sequential=sequential
     )
     testenc = {
-        "input_ids": testenc["input_ids"],
-        "attention_mask": testenc["attention_mask"]
+        "input_ids": testenc["input_ids"],
+        "attention_mask": testenc["attention_mask"],
     }
 
     return traindataset, testenc
 
 
-def get_ptb(nsamples, seed, seqlen, tokenizer, sequential=False, gptq_style=False):
+def get_ptb(
+    nsamples: int,
+    seed: int,
+    seqlen: int,
+    tokenizer: str,
+    sequential: bool = False,
+    gptq_style: bool = False,
+) -> tuple[dict, dict]:
     """Prepare data for GPTQ using PTB dataset.
 
     Args:
@@ -117,18 +131,20 @@ def get_ptb(nsamples, seed, seqlen, tokenizer, sequential=False, gptq_style=False):
     traindata = "\n\n".join(traindata["sentence"])
 
     trainenc = tokenizer(traindata)
-    testenc = tokenizer("\n\n".join(valdata["sentence"]),return_tensors="pt")
+    testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt")
 
     traindataset = return_tokenized_samples(nsamples, trainenc, seqlen, sequential)
     testenc = {
-        "input_ids": testenc["input_ids"],
-        "attention_mask": testenc["attention_mask"]
+        "input_ids": testenc["input_ids"],
+        "attention_mask": testenc["attention_mask"],
     }
 
     return traindataset, testenc
 
 
-def get_c4_train(nsamples, seed, seqlen, tokenizer, sequential=False):
+def get_c4_train(
+    nsamples: int, seed: int, seqlen: int, tokenizer: str, sequential: bool = False
+) -> tuple[dict, dict]:
     """Prepare data for GPTQ using C4 dataset.
 
     Args:
@@ -153,11 +169,11 @@ def get_c4_train(nsamples, seed, seqlen, tokenizer, sequential=False):
         split="validation",
     )
 
-    testenc = tokenizer("\n\n".join(valdata["text"]),return_tensors="pt")
+    testenc = tokenizer("\n\n".join(valdata["text"]), return_tensors="pt")
 
-    trainloader ={
-        "input_ids": torch.zeros(size = (nsamples, seqlen), dtype = torch.int),
-        "attention_mask": torch.zeros(size = (nsamples, seqlen), dtype = torch.int)
+    trainloader = {
+        "input_ids": torch.zeros(size=(nsamples, seqlen), dtype=torch.int),
+        "attention_mask": torch.zeros(size=(nsamples, seqlen), dtype=torch.int),
     }
     for k in range(nsamples):
         while True:
@@ -182,7 +198,7 @@ def get_c4_train(nsamples, seed, seqlen, tokenizer, sequential=False):
     return trainloader, testdataset
 
 
-def get_c4_new(nsamples, seed, seqlen, tokenizer):
+def get_c4_new(nsamples: int, seed: int, seqlen: int, tokenizer: str):
     """Prepare data for GPTQ using C4 dataset.
 
     Args:
@@ -227,8 +243,8 @@
 
 
 def get_self_instruct_starcoder(
-    nsamples, seed, seqlen, tokenizer, split_name="curated"
-): # pylint: disable=unused-argument
+    nsamples: int, seed: int, seqlen: int, tokenizer: str, split_name: str = "curated"
+) -> tuple[dict, dict]: # pylint: disable=unused-argument
     """Prepare data for GPTQ using starcoder dataset.
 
     Args:
@@ -244,8 +260,8 @@ def get_self_instruct_starcoder(
 
     eval_dataset = tokenizer(" ".join(cr_dataset[:]["output"]), return_tensors="pt")
     eval_dataset = {
-        "input_ids": eval_dataset["input_ids"],
-        "attention_mask": eval_dataset["attention_mask"]
+        "input_ids": eval_dataset["input_ids"],
+        "attention_mask": eval_dataset["attention_mask"],
     }
 
     cr_dataset.shuffle(seed)
@@ -255,13 +271,15 @@
         tokenizer.pad_token = tokenizer.eos_token
 
     trainloader = {
-        "input_ids": torch.zeros(size = (nsamples,seqlen), dtype=torch.int),
-        "attention_mask": torch.zeros(size = (nsamples,seqlen), dtype=torch.int)
+        "input_ids": torch.zeros(size=(nsamples, seqlen), dtype=torch.int),
+        "attention_mask": torch.zeros(size=(nsamples, seqlen), dtype=torch.int),
     }
     for k in range(nsamples):
         tokenized = tokenizer(
-            cr_dataset[k]["output"], return_tensors="pt",
-            padding="max_length", max_length = seqlen
+            cr_dataset[k]["output"],
+            return_tensors="pt",
+            padding="max_length",
+            max_length=seqlen,
         )
         trainloader["input_ids"][k] = tokenized.input_ids.squeeze(0)
         trainloader["attention_mask"][k] = tokenized.attention_mask.squeeze(0)
@@ -270,8 +288,13 @@
 
 
 def get_cobol_java_supervised(
-    nsamples, seed, seqlen=8192, tokenizer = "", split_name="both", file_path=None
-):
+    nsamples: int,
+    seed: int,
+    seqlen: int = 8192,
+    tokenizer: str = "",
+    split_name: str = "both",
+    file_path: str = None,
+) -> tuple[dict, dict]:
     """Prepare data for GPTQ using cobol/java dataset.
 
     Args:
@@ -294,17 +317,17 @@ def get_cobol_java_supervised(
 
     eval_dataset = tokenizer(data_dict_array["content"], return_tensors="pt")
     eval_dataset = {
-        "input_ids": eval_dataset["input_ids"],
-        "attention_mask": eval_dataset["attention_mask"]
+        "input_ids": eval_dataset["input_ids"],
+        "attention_mask": eval_dataset["attention_mask"],
     }
 
     random.shuffle(data_dict_array)
 
     nsamples = min(nsamples, len(data_dict_array))
 
     trainloader = {
-        "input_ids": torch.zeros(size = (nsamples,seqlen), dtype=torch.int),
-        "attention_mask": torch.zeros(size = (nsamples,seqlen), dtype=torch.int)
+        "input_ids": torch.zeros(size=(nsamples, seqlen), dtype=torch.int),
+        "attention_mask": torch.zeros(size=(nsamples, seqlen), dtype=torch.int),
     }
     added_ex = 0
 

@@ -343,15 +366,15 @@
 
 
 def get_tokenized_data(
-    name,
-    nsamples=128,
-    seqlen=2048,
-    tokenizer="",
-    seed=0,
-    gptq_style=False,
-    path_to_save=None,
-    field_name=None,
-):
+    name: str,
+    nsamples: int = 128,
+    seqlen: int = 2048,
+    tokenizer: str = "",
+    seed: int = 0,
+    gptq_style: bool = False,
+    path_to_save: str = None,
+    field_name: str = None,
+) -> tuple[dict, dict]:
     """Convenient function to get data. Default to get_wikitext2."""
 
     # Option 1: User provide a dataset from disk, only need to tokenize and format it.
@@ -422,7 +445,10 @@ def get_tokenized_data(
         )
     elif "java" in name:
         traindataset, testdataset = get_cobol_java_supervised(
-            nsamples, seed, seqlen, tokenizer,
+            nsamples,
+            seed,
+            seqlen,
+            tokenizer,
         )
     else:
         raise NotImplementedError(
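
For reference, the freshly hinted entry point is driven roughly as below. A minimal sketch, assuming the docstring's default dispatch to get_wikitext2() and an illustrative checkpoint; note that these functions call tokenizer(...) directly, so despite the str hint a tokenizer object is what gets passed in practice:

from transformers import AutoTokenizer

from fms_mo.utils.calib_data import get_tokenized_data

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint

# Two dicts keyed by "input_ids"/"attention_mask", matching the new
# -> tuple[dict, dict] return hint.
traindataset, testenc = get_tokenized_data(
    "wikitext2", nsamples=128, seqlen=2048, tokenizer=tokenizer, seed=0
)
print(traindataset["input_ids"].shape)  # torch.Size([128, 2048])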

fms_mo/utils/eval_utils.py

Lines changed: 4 additions & 4 deletions
@@ -116,9 +116,9 @@ def eval_llm_1GPU(qcfg, model, test_dataset, pre_cache_func=None, **kwargs): #
 
         # Shift so that tokens < n predict n
         shift_logits = lm_logits[:, :-1, :].contiguous().float()
-        shift_labels = test_dataset["input_ids"][:, (i * seq_len) : ((i + 1) * seq_len)][
-            :, 1:
-        ].to(dev)
+        shift_labels = test_dataset["input_ids"][
+            :, (i * seq_len) : ((i + 1) * seq_len)
+        ][:, 1:].to(dev)
         loss_fct = nn.CrossEntropyLoss()
         loss = loss_fct(
             shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
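
The shift_labels rewrite above only regroups the chained slices for readability; the result is unchanged, since slicing a window and then dropping its first column equals slicing from one past the window start. A small self-contained check, with illustrative shapes:

import torch

seq_len = 4
input_ids = torch.arange(12).reshape(1, 12)  # stand-in for test_dataset["input_ids"]

for i in range(2):
    chained = input_ids[:, (i * seq_len) : ((i + 1) * seq_len)][:, 1:]
    direct = input_ids[:, (i * seq_len) + 1 : ((i + 1) * seq_len)]
    assert torch.equal(chained, direct)  # both drop the window's first token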
@@ -144,7 +144,7 @@ def __init__(self, dataset, device, n_samples=160):
         self.dataset = dataset
         self.device = device
         # loading tokenized dataset.
-        self.dataset = dataset['input_ids'].to(device)
+        self.dataset = dataset["input_ids"].to(device)
         self.n_samples = n_samples
 
     @torch.no_grad()
