Skip to content

Commit b04e126

Browse files
committed
Fix bugs in ser_postprocessor and ser_visualizer for the inference stage, where no gt_label is available.
1 parent d9a3a5e commit b04e126

File tree

3 files changed

+92
-56
lines changed

3 files changed

+92
-56
lines changed

projects/LayoutLMv3/datasets/transforms/formatting.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,22 +98,23 @@ def transform(self, results: dict) -> dict:
9898
for key in self.ser_keys:
9999
if key not in results:
100100
continue
101-
value = to_tensor(results[key])
102-
inputs[key] = value
101+
inputs[key] = to_tensor(results[key])
103102
packed_results['inputs'] = inputs
104103

105104
# pack `data_samples`
106105
data_samples = []
107106
for truncation_idx in range(truncation_number):
108107
data_sample = SERDataSample()
109108
gt_label = LabelData()
110-
assert 'labels' in results, 'key `labels` not in results.'
111-
value = to_tensor(results['labels'][truncation_idx])
112-
gt_label.item = value
109+
if results.get('labels', None):
110+
gt_label.item = to_tensor(results['labels'][truncation_idx])
113111
data_sample.gt_label = gt_label
114112
meta = {}
115113
for key in self.meta_keys:
116-
meta[key] = results[key]
114+
if key == 'truncation_word_ids':
115+
meta[key] = results[key][truncation_idx]
116+
else:
117+
meta[key] = results[key]
117118
data_sample.set_metainfo(meta)
118119
data_samples.append(data_sample)
119120
packed_results['data_samples'] = data_samples

projects/LayoutLMv3/models/ser_postprocessor.py

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,10 @@
1616
class SERPostprocessor(nn.Module):
1717
"""PostProcessor for SER."""
1818

19-
def __init__(self,
20-
classes: Union[tuple, list],
21-
ignore_index: int = -100) -> None:
19+
def __init__(self, classes: Union[tuple, list]) -> None:
2220
super().__init__()
2321
self.other_label_name = find_other_label_name_of_biolabel(classes)
2422
self.id2biolabel = self._generate_id2biolabel_map(classes)
25-
self.ignore_index = ignore_index
2623
self.softmax = nn.Softmax(dim=-1)
2724

2825
def _generate_id2biolabel_map(self, classes: Union[tuple, list]) -> Dict:
@@ -43,42 +40,62 @@ def _generate_id2biolabel_map(self, classes: Union[tuple, list]) -> Dict:
4340
def __call__(self, outputs: torch.Tensor,
4441
data_samples: Sequence[SERDataSample]
4542
) -> Sequence[SERDataSample]:
43+
# merge several truncation data_sample to one data_sample
44+
assert all('truncation_word_ids' in d for d in data_samples), \
45+
'The key `truncation_word_ids` should be specified' \
46+
'in PackSERInputs.'
47+
truncation_word_ids = []
48+
for data_sample in data_samples:
49+
truncation_word_ids.append(data_sample.pop('truncation_word_ids'))
50+
merged_data_sample = copy.deepcopy(data_samples[0])
51+
merged_data_sample.set_metainfo(
52+
dict(truncation_word_ids=truncation_word_ids))
53+
flattened_word_ids = [
54+
word_id for word_ids in truncation_word_ids for word_id in word_ids
55+
]
56+
4657
# convert outputs dim from (truncation_num, max_length, label_num)
4758
# to (truncation_num * max_length, label_num)
4859
outputs = outputs.cpu().detach()
49-
truncation_num = outputs.size(0)
5060
outputs = torch.reshape(outputs, (-1, outputs.size(-1)))
51-
# merge gt label ids from data_samples
52-
gt_label_ids = [
53-
data_samples[truncation_idx].gt_label.item
54-
for truncation_idx in range(truncation_num)
55-
]
56-
gt_label_ids = torch.cat(gt_label_ids, dim=0).cpu().detach().numpy()
5761
# get pred label ids/scores from outputs
5862
probs = self.softmax(outputs)
5963
max_value, max_idx = torch.max(probs, -1)
6064
pred_label_ids = max_idx.numpy()
6165
pred_label_scores = max_value.numpy()
62-
# select valid token and convert iid to biolabel
63-
gt_biolabels = [
64-
self.id2biolabel[g] for (g, p) in zip(gt_label_ids, pred_label_ids)
65-
if g != self.ignore_index
66-
]
66+
67+
# determine whether it is an inference process
68+
if 'item' in data_samples[0].gt_label:
69+
# merge gt label ids from data_samples
70+
gt_label_ids = [
71+
data_sample.gt_label.item for data_sample in data_samples
72+
]
73+
gt_label_ids = torch.cat(
74+
gt_label_ids, dim=0).cpu().detach().numpy()
75+
gt_biolabels = [
76+
self.id2biolabel[g]
77+
for (w, g) in zip(flattened_word_ids, gt_label_ids)
78+
if w is not None
79+
]
80+
# update merged gt_label
81+
merged_data_sample.gt_label.item = gt_biolabels
82+
83+
# inference process do not have item in gt_label,
84+
# so select valid token with flattened_word_ids
85+
# rather than with gt_label_ids like official code.
6786
pred_biolabels = [
68-
self.id2biolabel[p] for (g, p) in zip(gt_label_ids, pred_label_ids)
69-
if g != self.ignore_index
87+
self.id2biolabel[p]
88+
for (w, p) in zip(flattened_word_ids, pred_label_ids)
89+
if w is not None
7090
]
7191
pred_biolabel_scores = [
72-
s for (g, s) in zip(gt_label_ids, pred_label_scores)
73-
if g != self.ignore_index
92+
s for (w, s) in zip(flattened_word_ids, pred_label_scores)
93+
if w is not None
7494
]
7595
# record pred_label
7696
pred_label = LabelData()
7797
pred_label.item = pred_biolabels
7898
pred_label.score = pred_biolabel_scores
79-
# merge several truncation data_sample to one data_sample
80-
merged_data_sample = copy.deepcopy(data_samples[0])
8199
merged_data_sample.pred_label = pred_label
82-
# update merged gt_label
83-
merged_data_sample.gt_label.item = gt_biolabels
100+
84101
return [merged_data_sample]

projects/LayoutLMv3/visualization/ser_visualizer.py

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -91,19 +91,13 @@ def _draw_instances(self, image: np.ndarray, bboxes: Union[np.ndarray,
9191
line_width=self.line_width,
9292
alpha=self.alpha)
9393

94-
# draw gt/pred labels
95-
if gt_labels is not None and pred_labels is not None:
94+
areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
95+
scales = _get_adaptive_scales(areas)
96+
positions = (bboxes[:, :2] + bboxes[:, 2:]) // 2
97+
98+
if gt_labels is not None:
9699
gt_tokens_biolabel = gt_labels.item
97100
gt_words_label = []
98-
pred_tokens_biolabel = pred_labels.item
99-
pred_words_label = []
100-
101-
if 'score' in pred_labels:
102-
pred_tokens_biolabel_score = pred_labels.score
103-
pred_words_label_score = []
104-
else:
105-
pred_tokens_biolabel_score = None
106-
pred_words_label_score = None
107101

108102
pre_word_id = None
109103
for idx, cur_word_id in enumerate(word_ids):
@@ -112,36 +106,60 @@ def _draw_instances(self, image: np.ndarray, bboxes: Union[np.ndarray,
112106
gt_words_label_name = gt_tokens_biolabel[idx][2:] \
113107
if gt_tokens_biolabel[idx] != 'O' else 'other'
114108
gt_words_label.append(gt_words_label_name)
109+
pre_word_id = cur_word_id
110+
assert len(gt_words_label) == len(bboxes)
111+
if pred_labels is not None:
112+
pred_tokens_biolabel = pred_labels.item
113+
pred_words_label = []
114+
pred_tokens_biolabel_score = pred_labels.score
115+
pred_words_label_score = []
116+
117+
pre_word_id = None
118+
for idx, cur_word_id in enumerate(word_ids):
119+
if cur_word_id is not None:
120+
if cur_word_id != pre_word_id:
115121
pred_words_label_name = pred_tokens_biolabel[idx][2:] \
116122
if pred_tokens_biolabel[idx] != 'O' else 'other'
117123
pred_words_label.append(pred_words_label_name)
118-
if pred_tokens_biolabel_score is not None:
119-
pred_words_label_score.append(
120-
pred_tokens_biolabel_score[idx])
124+
pred_words_label_score.append(
125+
pred_tokens_biolabel_score[idx])
121126
pre_word_id = cur_word_id
122-
assert len(gt_words_label) == len(bboxes)
123127
assert len(pred_words_label) == len(bboxes)
124128

125-
areas = (bboxes[:, 3] - bboxes[:, 1]) * (
126-
bboxes[:, 2] - bboxes[:, 0])
127-
scales = _get_adaptive_scales(areas)
128-
positions = (bboxes[:, :2] + bboxes[:, 2:]) // 2
129-
129+
# draw gt or pred labels
130+
if gt_labels is not None and pred_labels is not None:
130131
for i, (pos, gt, pred) in enumerate(
131132
zip(positions, gt_words_label, pred_words_label)):
132-
if pred_words_label_score is not None:
133-
score = round(float(pred_words_label_score[i]) * 100, 1)
134-
label_text = f'{gt} | {pred}({score})'
135-
else:
136-
label_text = f'{gt} | {pred}'
137-
133+
score = round(float(pred_words_label_score[i]) * 100, 1)
134+
label_text = f'{gt} | {pred}({score})'
138135
self.draw_texts(
139136
label_text,
140137
pos,
141138
colors=self.label_color if gt == pred else 'r',
142139
font_sizes=int(13 * scales[i]),
143140
vertical_alignments='center',
144141
horizontal_alignments='center')
142+
elif pred_labels is not None:
143+
for i, (pos, pred) in enumerate(zip(positions, pred_words_label)):
144+
score = round(float(pred_words_label_score[i]) * 100, 1)
145+
label_text = f'Pred: {pred}({score})'
146+
self.draw_texts(
147+
label_text,
148+
pos,
149+
colors=self.label_color,
150+
font_sizes=int(13 * scales[i]),
151+
vertical_alignments='center',
152+
horizontal_alignments='center')
153+
elif gt_labels is not None:
154+
for i, (pos, gt) in enumerate(zip(positions, gt_words_label)):
155+
label_text = f'GT: {gt}'
156+
self.draw_texts(
157+
label_text,
158+
pos,
159+
colors=self.label_color,
160+
font_sizes=int(13 * scales[i]),
161+
vertical_alignments='center',
162+
horizontal_alignments='center')
145163

146164
return self.get_image()
147165

0 commit comments

Comments
 (0)