Skip to content

Commit 0f3c956

Browse files
committed
修复由于分词词表带来的切分不一致问题 (fix segmentation inconsistency caused by the word-segmentation vocabulary) #466
1 parent aae913b commit 0f3c956

File tree

3 files changed

+60
-5
lines changed

3 files changed

+60
-5
lines changed

ltp/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# -*- coding: utf-8 -*_
33
# Author: Yunlong Feng <ylfeng@ir.hit.edu.cn>
44

5-
__version__ = '4.1.3'
5+
__version__ = '4.1.3.post1'
66

77
from . import const
88
from . import nn, utils

ltp/frontend.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def __init__(self, path: str = 'small', device=None, **kwargs):
141141
self.model.eval()
142142

143143
self.seg_vocab = ckpt.get('seg', [WORD_MIDDLE, WORD_START])
144+
self.seg_vocab_dict = {tag: idx for idx, tag in enumerate(self.seg_vocab)}
144145
self.pos_vocab = ckpt.get('pos', [])
145146
self.ner_vocab = ckpt.get('ner', [])
146147
self.dep_vocab = ckpt.get('dep', [])
@@ -255,10 +256,10 @@ def seg(self, inputs: Union[List[str], List[List[str]]], truncation: bool = True
255256
matches = self.seg_with_dict(inputs, tokenized, batch_prefix)
256257
for sent_match, sent_seg in zip(matches, seg):
257258
for start, end in sent_match:
258-
sent_seg[start] = 0
259-
sent_seg[start + 1:end] = 1
259+
sent_seg[start] = self.seg_vocab_dict[WORD_START]
260+
sent_seg[start + 1:end] = self.seg_vocab_dict[WORD_MIDDLE]
260261
if end < len(sent_seg):
261-
sent_seg[end] = 0
262+
sent_seg[end] = self.seg_vocab_dict[WORD_START]
262263

263264
if is_preseged:
264265
sentences = inputs

ltp/multitask.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
from argparse import ArgumentParser
88
from collections import OrderedDict
99

10+
import numpy
1011
import torch
1112
import torch.utils.data
1213
from pytorch_lightning import Trainer
14+
from tqdm import tqdm
1315

1416
import ltp
1517
from ltp import (
@@ -20,7 +22,7 @@
2022
from ltp.data import dataset as datasets
2123
from ltp.data.utils import collate, MultiTaskDataloader
2224
from ltp.transformer_multitask import TransformerMultiTask as Model
23-
from ltp.utils import TaskInfo, common_train, tune_train
25+
from ltp.utils import TaskInfo, common_train, tune_train, map2device, convert2npy
2426
from ltp.utils import deploy_model
2527

2628
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
@@ -190,6 +192,55 @@ def configure_optimizers(self: Model):
190192
)
191193

192194

195+
def build_ner_distill_dataset(args):
    """Dump soft-label NER training data for knowledge distillation.

    Loads the multi-task model from ``args.resume_from_checkpoint``, runs it
    over the NER training split without gradient tracking, and saves every
    batch together with the model's logits — and, when the NER head has a
    CRF, its transition parameters — to
    ``<args.ner_data_dir>/<task>/output.npz``.

    Args:
        args: parsed command-line namespace; must provide
            ``resume_from_checkpoint``, ``ner_data_dir``, ``batch_size``
            and ``num_workers``.
    """
    model = Model.load_from_checkpoint(
        args.resume_from_checkpoint, hparams=args
    )

    model.eval()
    model.freeze()

    # ``metric`` is not needed here — we only dump training data.
    dataset, _metric = task_named_entity_recognition.build_dataset(
        model, args.ner_data_dir, task_info.task_name
    )
    train_dataloader = torch.utils.data.DataLoader(
        dataset[datasets.Split.TRAIN],
        batch_size=args.batch_size,
        collate_fn=collate,
        num_workers=args.num_workers
    )

    output = os.path.join(args.ner_data_dir, task_info.task_name, 'output.npz')

    # Named helper functions instead of lambda assignments (PEP 8 E731).
    if torch.cuda.is_available():
        model.cuda()

        def map2cpu(batch):
            # map2device with no target moves tensors back to the CPU.
            return map2device(batch)

        def map2cuda(batch):
            return map2device(batch, model.device)
    else:
        def map2cpu(batch):
            return batch

        def map2cuda(batch):
            return batch

    with torch.no_grad():
        batches = []
        for batch in tqdm(train_dataloader):
            batch = map2cuda(batch)
            logits = model.forward(task='ner', **batch).logits
            batch.update(logits=logits)
            batches.append(map2cpu(batch))

    try:
        numpy.savez(
            output,
            data=convert2npy(batches),
            extra=convert2npy({
                'transitions': model.ner_classifier.crf.transitions,
                'start_transitions': model.ner_classifier.crf.start_transitions,
                'end_transitions': model.ner_classifier.crf.end_transitions
            })
        )
    except Exception as e:
        # The NER head may have no CRF (or the extras may fail to convert).
        # Fall back to saving the data alone, but report why instead of
        # silently swallowing the error.
        print(f"Saving CRF transitions failed ({e!r}); saving data only.")
        numpy.savez(output, data=convert2npy(batches))

    print("Done")
242+
243+
193244
def add_task_specific_args(parent_parser):
194245
parser = ArgumentParser(parents=[parent_parser], add_help=False)
195246
parser.add_argument('--seed', type=int, default=19980524)
@@ -210,6 +261,7 @@ def add_task_specific_args(parent_parser):
210261
parser.add_argument('--dep_data_dir', type=str, default=None)
211262
parser.add_argument('--sdp_data_dir', type=str, default=None)
212263
parser.add_argument('--srl_data_dir', type=str, default=None)
264+
parser.add_argument('--build_ner_dataset', action='store_true')
213265
return parser
214266

215267

@@ -226,6 +278,8 @@ def main():
226278

227279
if args.ltp_model is not None and args.resume_from_checkpoint is not None:
228280
deploy_model(args, args.ltp_version)
281+
elif args.build_ner_dataset:
282+
build_ner_distill_dataset(args)
229283
elif args.tune:
230284
tune_train(args, model_class=Model, task_info=task_info, build_method=build_method)
231285
else:

0 commit comments

Comments (0)