
Commit cf523c7

fix ie_utils import (#4066)
* fix import
* fix import
* fix import
1 parent b1bfbf2 commit cf523c7

File tree: 4 files changed (+34, -53 lines)


applications/information_extraction/document/finetune.py

Lines changed: 1 addition & 26 deletions
@@ -21,14 +21,14 @@
 from utils import convert_example, reader
 
 from paddlenlp.datasets import load_dataset
-from paddlenlp.metrics import SpanEvaluator
 from paddlenlp.trainer import (
     CompressionArguments,
     PdArgumentParser,
     Trainer,
     get_last_checkpoint,
 )
 from paddlenlp.transformers import UIEX, AutoTokenizer, export_model
+from paddlenlp.utils.ie_utils import compute_metrics, uie_loss_func
 from paddlenlp.utils.log import logger
 
 
@@ -113,31 +113,6 @@ def main():
     train_ds = train_ds.map(trans_fn)
     dev_ds = dev_ds.map(trans_fn)
 
-    criterion = paddle.nn.BCELoss()
-
-    def uie_loss_func(outputs, labels):
-        start_ids, end_ids = labels
-        start_prob, end_prob = outputs
-        start_ids = paddle.cast(start_ids, "float32")
-        end_ids = paddle.cast(end_ids, "float32")
-        loss_start = criterion(start_prob, start_ids)
-        loss_end = criterion(end_prob, end_ids)
-        loss = (loss_start + loss_end) / 2.0
-        return loss
-
-    def compute_metrics(p):
-        metric = SpanEvaluator()
-        start_prob, end_prob = p.predictions
-        start_ids, end_ids = p.label_ids
-        metric.reset()
-
-        num_correct, num_infer, num_label = metric.compute(start_prob, end_prob, start_ids, end_ids)
-        metric.update(num_correct, num_infer, num_label)
-        precision, recall, f1 = metric.accumulate()
-        metric.reset()
-
-        return {"precision": precision, "recall": recall, "f1": f1}
-
     trainer = Trainer(
         model=model,
         criterion=uie_loss_func,

applications/information_extraction/text/finetune.py

Lines changed: 1 addition & 25 deletions
@@ -31,6 +31,7 @@
     get_last_checkpoint,
 )
 from paddlenlp.transformers import UIE, UIEM, AutoTokenizer, export_model
+from paddlenlp.utils.ie_utils import compute_metrics, uie_loss_func
 from paddlenlp.utils.log import logger
 
 
@@ -141,31 +142,6 @@ def main():
 
     data_collator = DataCollatorWithPadding(tokenizer)
 
-    criterion = paddle.nn.BCELoss()
-
-    def uie_loss_func(outputs, labels):
-        start_ids, end_ids = labels
-        start_prob, end_prob = outputs
-        start_ids = paddle.cast(start_ids, "float32")
-        end_ids = paddle.cast(end_ids, "float32")
-        loss_start = criterion(start_prob, start_ids)
-        loss_end = criterion(end_prob, end_ids)
-        loss = (loss_start + loss_end) / 2.0
-        return loss
-
-    def compute_metrics(p):
-        metric = SpanEvaluator()
-        start_prob, end_prob = p.predictions
-        start_ids, end_ids = p.label_ids
-        metric.reset()
-
-        num_correct, num_infer, num_label = metric.compute(start_prob, end_prob, start_ids, end_ids)
-        metric.update(num_correct, num_infer, num_label)
-        precision, recall, f1 = metric.accumulate()
-        metric.reset()
-
-        return {"precision": precision, "recall": recall, "f1": f1}
-
     trainer = Trainer(
         model=model,
         criterion=uie_loss_func,
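
Both finetune scripts now import the shared loss and metric instead of defining local copies inside main(). The fragment below sketches the resulting wiring; it is not a runnable script on its own: names such as model, training_args, train_ds, dev_ds and data_collator stand for the objects the scripts build earlier, and passing compute_metrics to the Trainer is assumed to mirror the pre-existing setup rather than something shown in this diff.

from paddlenlp.trainer import Trainer
from paddlenlp.utils.ie_utils import compute_metrics, uie_loss_func

# The BCE span loss and the SpanEvaluator-based metric now live in
# paddlenlp.utils.ie_utils, so both scripts share one implementation.
trainer = Trainer(
    model=model,                      # UIE / UIEM / UIEX model built earlier
    criterion=uie_loss_func,          # averaged BCE over start/end probabilities
    args=training_args,               # assumed: parsed training arguments
    train_dataset=train_ds,
    eval_dataset=dev_ds,
    data_collator=data_collator,
    compute_metrics=compute_metrics,  # precision / recall / F1 from SpanEvaluator
)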

paddlenlp/taskflow/task.py

Lines changed: 4 additions & 2 deletions
@@ -246,7 +246,7 @@ def _get_inference_model(self):
         self._param_updated = True
         if os.path.exists(cache_info_path) and open(cache_info_path).read()[:-8] == md5:
             self._param_updated = False
-        elif self.task == "information_extraction":
+        elif self.task == "information_extraction" and self.model != "uie-data-distill-gp":
             # UIE related models are moved to paddlenlp.transformers after v2.4.5
             # So we convert the parameter key names for compatibility
             # This check will be discard in future
@@ -257,7 +257,9 @@ def _get_inference_model(self):
             prefix_map = {"UIE": "ernie", "UIEM": "ernie_m", "UIEX": "ernie_layout"}
             new_state_dict = {}
             for name, param in model_state.items():
-                if "encoder.encoder" in name:
+                if "ernie" in name:
+                    new_state_dict[name] = param
+                elif "encoder.encoder" in name:
                     trans_name = name.replace("encoder.encoder", prefix_map[self._init_class] + ".encoder")
                     new_state_dict[trans_name] = param
                 elif "encoder" in name:
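
For reference, the conversion performed by the new branches can be seen on a toy state dict. This is a self-contained illustration only: the prefix_map, the "ernie" check and the "encoder.encoder" rewrite come from the diff above, while init_class and the example key names are made up.

prefix_map = {"UIE": "ernie", "UIEM": "ernie_m", "UIEX": "ernie_layout"}
init_class = "UIE"  # hypothetical; Taskflow derives this from the loaded config

# Toy checkpoint keys: one already in the new format, one in the legacy format.
model_state = {
    "ernie.embeddings.word_embeddings.weight": "w0",
    "encoder.encoder.layers.0.linear1.weight": "w1",
}

new_state_dict = {}
for name, param in model_state.items():
    if "ernie" in name:
        # Already-converted keys are copied through unchanged (the new branch).
        new_state_dict[name] = param
    elif "encoder.encoder" in name:
        # Legacy keys are re-prefixed, e.g. encoder.encoder.* -> ernie.encoder.*
        trans_name = name.replace("encoder.encoder", prefix_map[init_class] + ".encoder")
        new_state_dict[trans_name] = param

print(sorted(new_state_dict))
# ['ernie.embeddings.word_embeddings.weight', 'ernie.encoder.layers.0.linear1.weight']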

paddlenlp/utils/ie_utils.py

Lines changed: 28 additions & 0 deletions
@@ -16,8 +16,10 @@
 from io import BytesIO
 
 import numpy as np
+import paddle
 from PIL import Image
 
+from ..metrics import SpanEvaluator
 from .image_utils import NormalizeImage, Permute, ResizeImage
 
 resize_func = ResizeImage(target_size=224, interp=1)
@@ -112,3 +114,29 @@ def compare(a, b, schema_lang="ch"):
             relation_type = prefix
         relation_type_dict.setdefault(relation_type, []).append(relation_data[i][1])
     return relation_type_dict
+
+
+def uie_loss_func(outputs, labels):
+    criterion = paddle.nn.BCELoss()
+    start_ids, end_ids = labels
+    start_prob, end_prob = outputs
+    start_ids = paddle.cast(start_ids, "float32")
+    end_ids = paddle.cast(end_ids, "float32")
+    loss_start = criterion(start_prob, start_ids)
+    loss_end = criterion(end_prob, end_ids)
+    loss = (loss_start + loss_end) / 2.0
+    return loss
+
+
+def compute_metrics(p):
+    metric = SpanEvaluator()
+    start_prob, end_prob = p.predictions
+    start_ids, end_ids = p.label_ids
+    metric.reset()
+
+    num_correct, num_infer, num_label = metric.compute(start_prob, end_prob, start_ids, end_ids)
+    metric.update(num_correct, num_infer, num_label)
+    precision, recall, f1 = metric.accumulate()
+    metric.reset()
+
+    return {"precision": precision, "recall": recall, "f1": f1}
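
A quick way to sanity-check the relocated helpers is to call them with random tensors. This is a minimal sketch assuming a working paddle install, that start/end probabilities and 0/1 labels share the shape [batch_size, seq_len], and that a SimpleNamespace is an acceptable stand-in for the EvalPrediction-style object the Trainer normally passes to compute_metrics:

from types import SimpleNamespace

import paddle

from paddlenlp.utils.ie_utils import compute_metrics, uie_loss_func

batch_size, seq_len = 4, 16

# Dummy start/end probabilities in [0, 1) and binary position labels.
start_prob = paddle.rand([batch_size, seq_len])
end_prob = paddle.rand([batch_size, seq_len])
start_ids = paddle.randint(0, 2, [batch_size, seq_len])
end_ids = paddle.randint(0, 2, [batch_size, seq_len])

# Averaged BCE loss over the start and end distributions.
loss = uie_loss_func((start_prob, end_prob), (start_ids, end_ids))
print(float(loss))

# compute_metrics reads .predictions and .label_ids from its argument.
p = SimpleNamespace(
    predictions=(start_prob.numpy(), end_prob.numpy()),
    label_ids=(start_ids.numpy(), end_ids.numpy()),
)
print(compute_metrics(p))  # e.g. {'precision': ..., 'recall': ..., 'f1': ...}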
