@@ -3,22 +3,17 @@
 
 import numpy as np
 from datasets import ClassLabel, load_dataset
-
 from evaluate import load
-from transformers import (
-    AutoModelForSequenceClassification,
-    AutoTokenizer,
-    DataCollatorWithPadding,
-    Trainer,
-    TrainerCallback,
-    TrainingArguments,
-    set_seed,
-)
+from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
+                          DataCollatorWithPadding, Trainer, TrainerCallback,
+                          TrainingArguments, set_seed)
 
 
 def get_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_ckpt", type=str, default="microsoft/unixcoder-base-nine")
+    parser.add_argument(
+        "--model_ckpt", type=str, default="microsoft/unixcoder-base-nine"
+    )
     parser.add_argument("--max_length", type=int, default=1024)
     parser.add_argument("--num_epochs", type=int, default=5)
     parser.add_argument("--batch_size", type=int, default=6)
@@ -52,7 +47,9 @@ def __init__(self, trainer) -> None:
     def on_epoch_end(self, args, state, control, **kwargs):
         if control.should_evaluate:
             control_copy = deepcopy(control)
-            self._trainer.evaluate(eval_dataset=self._trainer.train_dataset, metric_key_prefix="train")
+            self._trainer.evaluate(
+                eval_dataset=self._trainer.train_dataset, metric_key_prefix="train"
+            )
             return control_copy
 
 
@@ -61,21 +58,28 @@ def main():
     set_seed(args.seed)
 
     ds = load_dataset("code_x_glue_cc_clone_detection_big_clone_bench")
-    labels = ClassLabel(num_classes=2, names=[True, False])
+    labels = ClassLabel(num_classes=2, names=[True, False])
     ds = ds.cast_column("label", labels)
 
     print("Loading tokenizer and model")
     tokenizer = AutoTokenizer.from_pretrained(args.model_ckpt)
     tokenizer.pad_token = tokenizer.eos_token
-    model = AutoModelForSequenceClassification.from_pretrained(args.model_ckpt, num_labels=2)
+    model = AutoModelForSequenceClassification.from_pretrained(
+        args.model_ckpt, num_labels=2
+    )
     model.config.pad_token_id = model.config.eos_token_id
 
     if args.freeze:
         for param in model.roberta.parameters():
             param.requires_grad = False
 
     def tokenize(example):
-        inputs = tokenizer(example["func1"], example["func2"], truncation=True, max_length=args.max_length)
+        inputs = tokenizer(
+            example["func1"],
+            example["func2"],
+            truncation=True,
+            max_length=args.max_length,
+        )
         return {
             "input_ids": inputs["input_ids"],
             "attention_mask": inputs["attention_mask"],
@@ -121,10 +125,11 @@ def tokenize(example):
 
     result = trainer.evaluate(eval_dataset=tokenized_datasets["test"])
     print(f"Evaluation accuracy on the test set: {result['eval_accuracy']}")
-
+
     # push the model to the Hugging Face hub
     if args.push_to_hub:
         model.push_to_hub(args.model_hub_name)
 
+
 if __name__ == "__main__":
     main()
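
For context, the callback reformatted in the second hunk follows the standard TrainerCallback pattern: when the Trainer is about to evaluate at an epoch boundary, it additionally runs evaluate() over the training split, so train-set metrics are logged under a "train_" key prefix alongside the validation metrics. A minimal sketch assembled from the code above, assuming the usual Trainer wiring elsewhere in the script:

from copy import deepcopy

from transformers import TrainerCallback


class CustomCallback(TrainerCallback):
    # Mirrors the class shown in the diff: re-evaluate on the training
    # split whenever the Trainer is scheduled to evaluate.
    def __init__(self, trainer) -> None:
        super().__init__()
        self._trainer = trainer

    def on_epoch_end(self, args, state, control, **kwargs):
        if control.should_evaluate:
            # Copy the control flags first: the nested evaluate() call fires
            # its own callbacks, which may mutate `control` in place.
            control_copy = deepcopy(control)
            self._trainer.evaluate(
                eval_dataset=self._trainer.train_dataset, metric_key_prefix="train"
            )
            return control_copy

A callback like this would typically be registered after the Trainer is constructed, e.g. trainer.add_callback(CustomCallback(trainer)); add_callback is part of the public Trainer API.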