process.py — 79 lines (65 loc) · 2.92 KB
import re
from typing import List, Union
from dragon_baseline import DragonBaseline
from transformers import EarlyStoppingCallback
class DragonSubmission(DragonBaseline):
    """DRAGON baseline adapted to the
    joeranbosma/dragon-roberta-large-domain-specific model.

    Note: when changing the model, update the Dockerfile to pre-download
    that model.
    """

    def __init__(self, **kwargs):
        """Initialise the baseline, then override the model and training
        configuration for this submission."""
        super().__init__(**kwargs)
        self.model_name_or_path = "joeranbosma/dragon-roberta-large-domain-specific"
        # self.model_name_or_path = "joeranbosma/dragon-bert-base-mixed-domain"
        self.per_device_train_batch_size = 8
        self.gradient_accumulation_steps = 1
        self.gradient_checkpointing = False
        self.max_seq_length = 512
        self.learning_rate = 1e-05
        self.num_train_epochs = 20
        # Extra arguments forwarded to the HF Trainer via the baseline.
        extra_trainer_args = {
            "save_total_limit": 1,
            # precision/math
            "optim": "adamw_torch_fused",
            "use_liger_kernel": True,
            "fp16": True,
            # logging/checkpointing
            "logging_steps": 1000,
            "save_strategy": "best",
            "save_only_model": True,
            "load_best_model_at_end": True,
            "metric_for_best_model": "dragon",
            "greater_is_better": True,
        }
        self.model_kwargs.update(extra_trainer_args)

    def custom_text_cleaning(self, text: Union[str, List[str]]) -> Union[str, List[str]]:
        """Remove HTML tags and URLs from the input text.

        Args:
            text: A string, or a list of strings, to be cleaned.

        Returns:
            The cleaned string when a string was passed in; otherwise a
            list with each element cleaned.
        """
        if not isinstance(text, str):
            # List input: clean each element recursively.
            return [self.custom_text_cleaning(item) for item in text]
        without_tags = re.sub(r"<.*?>", "", text)
        return re.sub(r"http\S+", "", without_tags)

    def preprocess(self):
        """Run the baseline preprocessing (custom text cleaning is left
        disabled by default)."""
        super().preprocess()
        # Uncomment the following lines to use the custom_text_cleaning function
        # for df in [self.df_train, self.df_val, self.df_test]:
        #     df[self.task.input_name] = df[self.task.input_name].map(self.custom_text_cleaning)
if __name__ == "__main__":
    # Train the adapted model, predict on the test split, and persist results.
    submission = DragonSubmission()
    submission.setup()
    # Halt training once the "dragon" metric stops improving by >= 0.005
    # for 5 consecutive evaluations.
    early_stopping = EarlyStoppingCallback(
        early_stopping_patience=5, early_stopping_threshold=0.005
    )
    submission.trainer.add_callback(early_stopping)
    submission.train()
    test_predictions = submission.predict(df=submission.df_test)
    submission.save(test_predictions)
    submission.verify_predictions()