Skip to content

Commit 18287d3

Browse files
authored
[UTC] Support single-label and multiple-label classification (#5083)
1 parent 35be940 commit 18287d3

File tree

6 files changed

+67
-15
lines changed

6 files changed

+67
-15
lines changed

applications/zero_shot_text_classification/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ python -u -m paddle.distributed.launch --gpus "0,1" run_train.py \
153153
该示例代码中由于设置了参数 `--do_eval`,因此在训练完会自动进行评估。
154154

155155
可配置参数说明:
156+
* `single_label`: 每条样本是否只预测一个标签。默认为`False`,表示多标签分类。
156157
* `device`: 训练设备,可选择 'cpu'、'gpu' 其中的一种;默认为 GPU 训练。
157158
* `logging_steps`: 训练过程中日志打印的间隔 steps 数,默认10。
158159
* `save_steps`: 训练过程中保存模型 checkpoint 的间隔 steps 数,默认100。
@@ -199,6 +200,7 @@ python run_eval.py \
199200
- `test_path`: 进行评估的测试集文件。
200201
- `per_device_eval_batch_size`: 批处理大小,请结合机器情况进行调整,默认为16。
201202
- `max_seq_len`: 文本最大切分长度,输入超过最大长度时会对输入文本进行自动切分,默认为512。
203+
- `single_label`: 每条样本是否只预测一个标签。默认为`False`,表示多标签分类。
202204

203205
<a name="定制模型一键预测"></a>
204206

applications/zero_shot_text_classification/label_studio.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,24 +54,28 @@ def convert_utc_examples(self, raw_examples):
5454
utc_examples = []
5555
for example in raw_examples:
5656
raw_text = example["data"]["text"].split(self.text_separator)
57-
raw_label = example["annotations"][0]["result"][0]["value"]["choices"][0]
5857
if len(raw_text) < 1:
5958
continue
6059
elif len(raw_text) == 1:
6160
raw_text.append("")
6261
elif len(raw_text) > 2:
6362
raw_text = ["".join(raw_text[:-1]), raw_text[-1]]
64-
if raw_label not in self.options:
65-
raise ValueError(
66-
f"Label `{raw_label}` not found in label candidates `options`. Please recheck the data."
67-
)
63+
64+
label_list = []
65+
for raw_label in example["annotations"][0]["result"][0]["value"]["choices"]:
66+
if raw_label not in self.options:
67+
raise ValueError(
68+
f"Label `{raw_label}` not found in label candidates `options`. Please recheck the data."
69+
)
70+
label_list.append(np.where(np.array(self.options) == raw_label)[0].tolist()[0])
71+
6872
utc_examples.append(
6973
{
7074
"text_a": raw_text[0],
7175
"text_b": raw_text[1],
7276
"question": "",
7377
"choices": self.options,
74-
"labels": np.where(np.array(self.options) == raw_label)[0].tolist()[0],
78+
"labels": label_list,
7579
}
7680
)
7781
return utc_examples

applications/zero_shot_text_classification/label_studio_text.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ label-studio start
8383

8484
项目创建后,可在Setting/Labeling Interface中继续配置标签,详见[项目创建](#label)
8585

86+
默认模式为单标签多分类数据标注。对于多标签多分类数据标注,需要将`choice`的值由`single`改为`multiple`
87+
88+
<div align="center">
89+
<img src=https://user-images.githubusercontent.com/25607475/222630045-8d6eebf7-572f-43d2-b7a1-24bf21a47fad.png />
90+
</div>
91+
8692
<a name="24"></a>
8793

8894
#### 2.4 任务标注

applications/zero_shot_text_classification/label_studio_text_en.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,12 @@ You can continue to import local txt format data after project creation. See mor
8484

8585
After project creation, you can add/delete labels in Setting/Labeling Interface just as in [Project Creation](#label)
8686

87+
Label Studio uses single-label annotation by default. For multi-label annotation, change the value of `choice` from `single` to `multiple` in the `code` tab.
88+
89+
<div align="center">
90+
<img src=https://user-images.githubusercontent.com/25607475/222630045-8d6eebf7-572f-43d2-b7a1-24bf21a47fad.png />
91+
</div>
92+
8793
<a name="24"></a>
8894

8995
#### 2.4 Task annotation

applications/zero_shot_text_classification/run_eval.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from dataclasses import dataclass, field
1818

1919
import paddle
20+
from paddle.metric import Accuracy
2021
from sklearn.metrics import f1_score
2122
from utils import UTCLoss, read_local_dataset
2223

@@ -35,6 +36,7 @@
3536
class DataArguments:
3637
test_path: str = field(default="./data/test.txt", metadata={"help": "Test dataset file name."})
3738
threshold: float = field(default=0.5, metadata={"help": "The threshold to produce predictions."})
39+
# Boolean switch: when truthy, evaluation uses the single-label (softmax/argmax) path.
# Annotated as `bool` to match its `False` default; with a `str` annotation a
# command-line value such as "False" would stay a non-empty, truthy string.
# NOTE(review): presumably parsed by a dataclass-based argument parser - confirm.
single_label: bool = field(default=False, metadata={"help": "Predict exactly one label per sample."})
3840

3941

4042
@dataclass
@@ -71,6 +73,18 @@ def main():
7173
prompt_model.set_state_dict(model_state)
7274

7375
# Define the metric function.
76+
def compute_metrics_single_label(eval_preds):
77+
labels = paddle.to_tensor(eval_preds.label_ids, dtype="int64")
78+
preds = paddle.to_tensor(eval_preds.predictions)
79+
preds = paddle.nn.functional.softmax(preds, axis=-1)
80+
labels = paddle.argmax(labels, axis=-1)
81+
print(preds, labels)
82+
metric = Accuracy()
83+
correct = metric.compute(preds, labels)
84+
metric.update(correct)
85+
acc = metric.accumulate()
86+
return {"accuracy": acc}
87+
7488
def compute_metrics(eval_preds):
7589
labels = paddle.to_tensor(eval_preds.label_ids, dtype="int64")
7690
preds = paddle.to_tensor(eval_preds.predictions)
@@ -92,7 +106,7 @@ def compute_metrics(eval_preds):
92106
train_dataset=None,
93107
eval_dataset=None,
94108
callbacks=None,
95-
compute_metrics=compute_metrics,
109+
compute_metrics=compute_metrics_single_label if data_args.single_label else compute_metrics,
96110
)
97111

98112
if data_args.test_path is not None:
@@ -102,12 +116,20 @@ def compute_metrics(eval_preds):
102116
json.dump(test_ret.metrics, fp)
103117

104118
with open(os.path.join(training_args.output_dir, "test_predictions.json"), "w", encoding="utf-8") as fp:
105-
preds = paddle.nn.functional.sigmoid(paddle.to_tensor(test_ret.predictions))
106-
for index, pred in enumerate(preds):
107-
result = {"id": index}
108-
result["labels"] = paddle.where(pred > data_args.threshold)[0].tolist()
109-
result["probs"] = pred[pred > data_args.threshold].tolist()
110-
fp.write(json.dumps(result, ensure_ascii=False) + "\n")
119+
if data_args.single_label:
120+
preds = paddle.nn.functional.softmax(paddle.to_tensor(test_ret.predictions), axis=-1)
121+
for index, pred in enumerate(preds):
122+
result = {"id": index}
123+
result["labels"] = paddle.argmax(pred).item()
124+
result["probs"] = pred[result["labels"]].item()
125+
fp.write(json.dumps(result, ensure_ascii=False) + "\n")
126+
else:
127+
preds = paddle.nn.functional.sigmoid(paddle.to_tensor(test_ret.predictions))
128+
for index, pred in enumerate(preds):
129+
result = {"id": index}
130+
result["labels"] = paddle.where(pred > data_args.threshold)[0].tolist()
131+
result["probs"] = pred[pred > data_args.threshold].tolist()
132+
fp.write(json.dumps(result, ensure_ascii=False) + "\n")
111133

112134

113135
if __name__ == "__main__":

applications/zero_shot_text_classification/run_train.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from dataclasses import dataclass, field
1616

1717
import paddle
18+
from paddle.metric import Accuracy
1819
from paddle.static import InputSpec
1920
from sklearn.metrics import f1_score
2021
from utils import UTCLoss, read_local_dataset
@@ -39,6 +40,7 @@ class DataArguments:
3940
train_file: str = field(default="train.txt", metadata={"help": "Train dataset file name."})
4041
dev_file: str = field(default="dev.txt", metadata={"help": "Dev dataset file name."})
4142
threshold: float = field(default=0.5, metadata={"help": "The threshold to produce predictions."})
43+
# Boolean switch: when truthy, training-time evaluation reports single-label accuracy.
# Annotated as `bool` to match its `False` default; with a `str` annotation a
# command-line value such as "False" would stay a non-empty, truthy string.
# NOTE(review): presumably parsed by a dataclass-based argument parser - confirm.
single_label: bool = field(default=False, metadata={"help": "Predict exactly one label per sample."})
4244

4345

4446
@dataclass
@@ -92,10 +94,20 @@ def main():
9294
)
9395

9496
# Define the metric function.
95-
def compute_metrics(eval_preds):
97+
def compute_metrics_single_label(eval_preds):
9698
labels = paddle.to_tensor(eval_preds.label_ids, dtype="int64")
9799
preds = paddle.to_tensor(eval_preds.predictions)
100+
preds = paddle.nn.functional.softmax(preds, axis=-1)
101+
labels = paddle.argmax(labels, axis=-1)
102+
metric = Accuracy()
103+
correct = metric.compute(preds, labels)
104+
metric.update(correct)
105+
acc = metric.accumulate()
106+
return {"accuracy": acc}
98107

108+
def compute_metrics(eval_preds):
109+
labels = paddle.to_tensor(eval_preds.label_ids, dtype="int64")
110+
preds = paddle.to_tensor(eval_preds.predictions)
99111
preds = paddle.nn.functional.sigmoid(preds)
100112
preds = preds[labels != -100].numpy()
101113
labels = labels[labels != -100].numpy()
@@ -113,7 +125,7 @@ def compute_metrics(eval_preds):
113125
train_dataset=train_ds,
114126
eval_dataset=dev_ds,
115127
callbacks=None,
116-
compute_metrics=compute_metrics,
128+
compute_metrics=compute_metrics_single_label if data_args.single_label else compute_metrics,
117129
)
118130

119131
# Training.

0 commit comments

Comments
 (0)