diff --git a/seqeval/metrics/sequence_labeling.py b/seqeval/metrics/sequence_labeling.py index b5807a3..63bbf30 100644 --- a/seqeval/metrics/sequence_labeling.py +++ b/seqeval/metrics/sequence_labeling.py @@ -159,15 +159,15 @@ def get_entities(seq, suffix=False): """ def _validate_chunk(chunk, suffix): - if chunk in ['O', 'B', 'I', 'E', 'S']: + if chunk in ['O', 'B', 'I', 'E', 'S', 'M']: return if suffix: - if not chunk.endswith(('-B', '-I', '-E', '-S')): + if not chunk.endswith(('-B', '-I', '-E', '-S', '-M')): warnings.warn('{} seems not to be NE tag.'.format(chunk)) else: - if not chunk.startswith(('B-', 'I-', 'E-', 'S-')): + if not chunk.startswith(('B-', 'I-', 'E-', 'S-', 'M-')): warnings.warn('{} seems not to be NE tag.'.format(chunk)) # for nested list @@ -229,6 +229,12 @@ def end_of_chunk(prev_tag, tag, prev_type, type_): chunk_end = True if prev_tag == 'I' and tag == 'O': chunk_end = True + if prev_tag == 'M' and tag == 'B': + chunk_end = True + if prev_tag == 'M' and tag == 'S': + chunk_end = True + if prev_tag == 'M' and tag == 'O': + chunk_end = True if prev_tag != 'O' and prev_tag != '.' and prev_type != type_: chunk_end = True @@ -267,6 +273,10 @@ def start_of_chunk(prev_tag, tag, prev_type, type_): chunk_start = True if prev_tag == 'O' and tag == 'I': chunk_start = True + if prev_tag == 'M' and tag == 'E': + chunk_start = True + if prev_tag == 'M' and tag == 'I': + chunk_start = True if tag != 'O' and tag != '.' and prev_type != type_: chunk_start = True