 import torch
 
 
-def return_tokenized_samples(nsamples, trainenc, seqlen, sequential=False) -> dict[str, torch.int]:
+def return_tokenized_samples(
+    nsamples: int, trainenc: list, seqlen: int, sequential: bool = False
+) -> dict:
     """Randomly crop nsamples sequence from trainenc, each with the length of seqlen.
     see below functions, e.g. get_wikitext2() for more details.
     """
     traindataset = {
-        "input_ids": torch.zeros(size=(nsamples, seqlen), dtype=torch.int),
-        "attention_mask": torch.zeros(size=(nsamples, seqlen), dtype=torch.int)
+        "input_ids": torch.zeros(size=(nsamples, seqlen), dtype=torch.int),
+        "attention_mask": torch.zeros(size=(nsamples, seqlen), dtype=torch.int),
     }
     i = 0
 
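For orientation, the re-typed helper takes a fully tokenized corpus and crops nsamples fixed-length windows out of it (the cropping itself happens in the body elided by this hunk, and despite the new `trainenc: list` hint the callers in this file pass the tokenizer's output directly). A minimal usage sketch, assuming the module's functions are importable; the checkpoint name is illustrative only:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # illustrative checkpoint
text = "\n\n".join(["a short calibration paragraph"] * 2000)
trainenc = tokenizer(text)  # BatchEncoding with "input_ids" and "attention_mask"

# Crop 4 random windows of 512 tokens each; the result mirrors the
# zero-initialized buffers above: a dict of (nsamples, seqlen) int tensors.
calib = return_tokenized_samples(4, trainenc, seqlen=512, sequential=False)
print(calib["input_ids"].shape)  # torch.Size([4, 512])
```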
@@ -57,8 +59,13 @@ def return_tokenized_samples(nsamples, trainenc, seqlen, sequential=False) -> di
 
 
 def get_wikitext2(
-    nsamples, seed, seqlen, tokenizer, sequential=False, gptq_style=False
-):
+    nsamples: int,
+    seed: int,
+    seqlen: int,
+    tokenizer: str,
+    sequential: bool = False,
+    gptq_style: bool = False,
+) -> tuple[dict, dict]:
     """Prepare data for GPTQ using wikitext2 dataset.
 
     Args:
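A hedged call sketch for the newly annotated signature. The `tokenizer: str` hint notwithstanding, the function body calls `tokenizer(...)`, so a Hugging Face tokenizer object is what actually gets passed; the checkpoint name is illustrative only:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # illustrative checkpoint
traindataset, testenc = get_wikitext2(
    nsamples=128,
    seed=0,
    seqlen=2048,
    tokenizer=tokenizer,
    sequential=False,
    gptq_style=False,
)
# Both returned dicts carry "input_ids" and "attention_mask"; traindataset
# holds 128 calibration crops of length 2048, testenc the tokenized held-out text.
```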
@@ -87,14 +94,21 @@ def get_wikitext2(
         nsamples, trainenc, seqlen, sequential=sequential
     )
     testenc = {
-        "input_ids": testenc["input_ids"],
-        "attention_mask": testenc["attention_mask"]
+        "input_ids": testenc["input_ids"],
+        "attention_mask": testenc["attention_mask"],
     }
 
     return traindataset, testenc
 
 
-def get_ptb(nsamples, seed, seqlen, tokenizer, sequential=False, gptq_style=False):
+def get_ptb(
+    nsamples: int,
+    seed: int,
+    seqlen: int,
+    tokenizer: str,
+    sequential: bool = False,
+    gptq_style: bool = False,
+) -> tuple[dict, dict]:
     """Prepare data for GPTQ using PTB dataset.
 
     Args:
@@ -117,18 +131,20 @@ def get_ptb(nsamples, seed, seqlen, tokenizer, sequential=False, gptq_style=Fals
     traindata = "\n\n".join(traindata["sentence"])
 
     trainenc = tokenizer(traindata)
-    testenc = tokenizer("\n\n".join(valdata["sentence"]),return_tensors="pt")
+    testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt")
 
     traindataset = return_tokenized_samples(nsamples, trainenc, seqlen, sequential)
     testenc = {
-        "input_ids": testenc["input_ids"],
-        "attention_mask": testenc["attention_mask"]
+        "input_ids": testenc["input_ids"],
+        "attention_mask": testenc["attention_mask"],
     }
 
     return traindataset, testenc
 
 
-def get_c4_train(nsamples, seed, seqlen, tokenizer, sequential=False):
+def get_c4_train(
+    nsamples: int, seed: int, seqlen: int, tokenizer: str, sequential: bool = False
+) -> tuple[dict, dict]:
     """Prepare data for GPTQ using C4 dataset.
 
     Args:
@@ -153,11 +169,11 @@ def get_c4_train(nsamples, seed, seqlen, tokenizer, sequential=False):
         split="validation",
     )
 
-    testenc = tokenizer("\n\n".join(valdata["text"]),return_tensors="pt")
+    testenc = tokenizer("\n\n".join(valdata["text"]), return_tensors="pt")
 
-    trainloader = {
-        "input_ids": torch.zeros(size=(nsamples, seqlen), dtype=torch.int),
-        "attention_mask": torch.zeros(size=(nsamples, seqlen), dtype=torch.int)
+    trainloader = {
+        "input_ids": torch.zeros(size=(nsamples, seqlen), dtype=torch.int),
+        "attention_mask": torch.zeros(size=(nsamples, seqlen), dtype=torch.int),
     }
     for k in range(nsamples):
         while True:
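The `while True:` that begins here follows the standard C4 calibration recipe: keep drawing random documents until one tokenizes to more than seqlen tokens, then cut a random seqlen-long window from it. A hedged sketch of that pattern under those assumptions (the helper and variable names are mine, not necessarily the file's):

```python
import random

import torch


def sample_c4_window(traindata, tokenizer, seqlen):
    """Draw one random fixed-length calibration window from a C4-style split."""
    while True:
        doc = traindata[random.randint(0, len(traindata) - 1)]
        enc = tokenizer(doc["text"], return_tensors="pt")
        if enc.input_ids.shape[1] > seqlen:
            break  # long enough to crop a full window
    start = random.randint(0, enc.input_ids.shape[1] - seqlen - 1)
    input_ids = enc.input_ids[0, start : start + seqlen]
    attention_mask = torch.ones(seqlen, dtype=torch.int)
    return input_ids, attention_mask
```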
@@ -182,7 +198,7 @@ def get_c4_train(nsamples, seed, seqlen, tokenizer, sequential=False):
     return trainloader, testdataset
 
 
-def get_c4_new(nsamples, seed, seqlen, tokenizer):
+def get_c4_new(nsamples: int, seed: int, seqlen: int, tokenizer: str):
     """Prepare data for GPTQ using C4 dataset.
 
     Args:
@@ -227,8 +243,8 @@ def get_c4_new(nsamples, seed, seqlen, tokenizer):
 
 
 def get_self_instruct_starcoder(
-    nsamples, seed, seqlen, tokenizer, split_name="curated"
-):  # pylint: disable=unused-argument
+    nsamples: int, seed: int, seqlen: int, tokenizer: str, split_name: str = "curated"
+) -> tuple[dict, dict]:  # pylint: disable=unused-argument
     """Prepare data for GPTQ using starcoder dataset.
 
     Args:
@@ -244,8 +260,8 @@ def get_self_instruct_starcoder(
 
     eval_dataset = tokenizer(" ".join(cr_dataset[:]["output"]), return_tensors="pt")
     eval_dataset = {
-        "input_ids": eval_dataset["input_ids"],
-        "attention_mask": eval_dataset["attention_mask"]
+        "input_ids": eval_dataset["input_ids"],
+        "attention_mask": eval_dataset["attention_mask"],
     }
 
     cr_dataset.shuffle(seed)
@@ -255,13 +271,15 @@ def get_self_instruct_starcoder(
         tokenizer.pad_token = tokenizer.eos_token
 
     trainloader = {
-        "input_ids": torch.zeros(size=(nsamples,seqlen), dtype=torch.int),
-        "attention_mask": torch.zeros(size=(nsamples,seqlen), dtype=torch.int)
+        "input_ids": torch.zeros(size=(nsamples, seqlen), dtype=torch.int),
+        "attention_mask": torch.zeros(size=(nsamples, seqlen), dtype=torch.int),
     }
     for k in range(nsamples):
         tokenized = tokenizer(
-            cr_dataset[k]["output"], return_tensors="pt",
-            padding="max_length", max_length=seqlen
+            cr_dataset[k]["output"],
+            return_tensors="pt",
+            padding="max_length",
+            max_length=seqlen,
        )
         trainloader["input_ids"][k] = tokenized.input_ids.squeeze(0)
         trainloader["attention_mask"][k] = tokenized.attention_mask.squeeze(0)
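Each StarCoder sample is tokenized on its own here, so `padding="max_length"` with `max_length=seqlen` pads short outputs up to the width of the pre-allocated (nsamples, seqlen) buffers, and `squeeze(0)` drops the batch dimension added by `return_tensors="pt"`. A self-contained illustration of that shape contract; the checkpoint is illustrative, and `truncation=True` is an assumption added for the example (it is not in the hunk) so the output is always exactly max_length wide:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bigcode/starcoderbase-1b")  # illustrative checkpoint
if tok.pad_token is None:
    tok.pad_token = tok.eos_token  # same fallback as above

enc = tok(
    "def add(a, b):\n    return a + b\n",
    return_tensors="pt",
    padding="max_length",
    max_length=32,
    truncation=True,  # added for the example; guarantees a fixed width
)
print(enc.input_ids.shape)             # torch.Size([1, 32])
print(enc.input_ids.squeeze(0).shape)  # torch.Size([32]), one row of the buffer
```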
@@ -270,8 +288,13 @@ def get_self_instruct_starcoder(
 
 
 def get_cobol_java_supervised(
-    nsamples, seed, seqlen=8192, tokenizer="", split_name="both", file_path=None
-):
+    nsamples: int,
+    seed: int,
+    seqlen: int = 8192,
+    tokenizer: str = "",
+    split_name: str = "both",
+    file_path: str = None,
+) -> tuple[dict, dict]:
     """Prepare data for GPTQ using cobol/java dataset.
 
     Args:
@@ -294,17 +317,17 @@ def get_cobol_java_supervised(
 
     eval_dataset = tokenizer(data_dict_array["content"], return_tensors="pt")
     eval_dataset = {
-        "input_ids": eval_dataset["input_ids"],
-        "attention_mask": eval_dataset["attention_mask"]
+        "input_ids": eval_dataset["input_ids"],
+        "attention_mask": eval_dataset["attention_mask"],
     }
 
     random.shuffle(data_dict_array)
 
     nsamples = min(nsamples, len(data_dict_array))
 
     trainloader = {
-        "input_ids": torch.zeros(size=(nsamples,seqlen), dtype=torch.int),
-        "attention_mask": torch.zeros(size=(nsamples,seqlen), dtype=torch.int)
+        "input_ids": torch.zeros(size=(nsamples, seqlen), dtype=torch.int),
+        "attention_mask": torch.zeros(size=(nsamples, seqlen), dtype=torch.int),
     }
     added_ex = 0
 
@@ -343,15 +366,15 @@ def get_cobol_java_supervised(
 
 
 def get_tokenized_data(
-    name,
-    nsamples=128,
-    seqlen=2048,
-    tokenizer="",
-    seed=0,
-    gptq_style=False,
-    path_to_save=None,
-    field_name=None,
-):
+    name: str,
+    nsamples: int = 128,
+    seqlen: int = 2048,
+    tokenizer: str = "",
+    seed: int = 0,
+    gptq_style: bool = False,
+    path_to_save: str = None,
+    field_name: str = None,
+) -> tuple[dict, dict]:
     """Convenient function to get data. Default to get_wikitext2."""
 
     # Option 1: User provide a dataset from disk, only need to tokenize and format it.
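Since `get_tokenized_data` is the dispatch entry point most callers use, a hedged usage sketch against the annotated defaults (again passing a real tokenizer object where the hint says `str`; the checkpoint name is illustrative):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/opt-125m")  # illustrative checkpoint
traindataset, testdataset = get_tokenized_data(
    "wikitext2",   # dataset name; other strings dispatch to PTB, C4, StarCoder, etc.
    nsamples=128,
    seqlen=2048,
    tokenizer=tok,
    seed=0,
)
```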
@@ -422,7 +445,10 @@ def get_tokenized_data(
         )
     elif "java" in name:
         traindataset, testdataset = get_cobol_java_supervised(
-            nsamples, seed, seqlen, tokenizer,
+            nsamples,
+            seed,
+            seqlen,
+            tokenizer,
         )
     else:
         raise NotImplementedError(