Commit 176a50c

update sacr parsing
1 parent ffd3013

File tree

1 file changed: +52 -38 lines changed


tibert/bertcoref.py

Lines changed: 52 additions & 38 deletions
@@ -332,9 +332,9 @@ def torch_call(self, features) -> Union[dict, BatchEncoding]:
             # same length yet.
             return_tensors=None,
         )
-        self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = (
-            warning_state
-        )
+        self.tokenizer.deprecation_warnings[
+            "Asking-to-pad-a-fast-tokenizer"
+        ] = warning_state
 
         # keep encoding info
         batch._encodings = [f.encodings[0] for f in features]
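
These lines are the tail of a save-and-restore pattern around tokenizer.pad: the collator marks the "Asking-to-pad-a-fast-tokenizer" warning as already emitted so that pad() stays silent, then restores the previous state afterwards. A minimal sketch of the full pattern, assuming a HuggingFace fast tokenizer (warning_state, features, and return_tensors=None appear in the diff context; the surrounding method body is reconstructed, not quoted):

warning_state = self.tokenizer.deprecation_warnings.get(
    "Asking-to-pad-a-fast-tokenizer", False
)
# marking the warning as already emitted keeps pad() silent
self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
batch = self.tokenizer.pad(
    features,
    # no tensor conversion: sequences do not have the
    # same length yet.
    return_tensors=None,
)
# put back whatever warning state the tokenizer had before
self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state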
@@ -532,50 +532,64 @@ def from_sacr_dir(
 
         for fpath in tqdm(paths):
             with open(fpath, **kwargs) as f:
-                text = f.read().replace("\n", " ")
+                text = f.read()
+                text = re.sub(r"#COLOR:.*\n", "", text)
+                text = re.sub(r"#TOKENIZATION-TYPE:.*\n", "", text)
+                text = text.replace("\n", " ")
 
             def parse(text: str) -> Tuple[List[str], Dict[str, List[Mention]]]:
-                splitted = re.split(r"({T[0-9]+:EN=\".*?\" [^{]*})", text)
-
-                if len(splitted) == 1:
-                    return m_tokenizer.tokenize(text, escape=False), {}
-
+                # SACR format example:
+                # {T109:EN="p PER" Le nouveau-né} s’agite dans {T109:EN="p PER" son} berceau.
+                # This format can be nested
                 tokens: List[str] = []
                 # { id => chain }
                 chains: Dict[str, List[Mention]] = defaultdict(list)
 
-                # SACR format example:
-                # {T109:EN="p PER" Le nouveau-né} s’agite dans {T109:EN="p PER" son} berceau.
-                #
-                # split the text using a pattern that matches text
-                # between braces. The text variable has,
-                # alternatively, either the text between braces or
-                # regulare text.
-                for i, text in enumerate(splitted):
-                    # regular text
-                    if i % 2 == 0:
-                        text_tokens = m_tokenizer.tokenize(text, escape=False)
-                        tokens += text_tokens
-                    # text inside braces represents a coreference mention
-                    else:
-                        text_match = re.search(r"{T([0-9]+):EN=\".*?\" (.*)}", text)
-                        assert not text_match is None
-                        text_tokens, subchains = parse(text_match.group(2))
-
-                        for chain_key, mentions in subchains.items():
-                            chains[chain_key] += [
-                                m.shifted(len(tokens)) for m in mentions
-                            ]
-
-                        chains[text_match.group(1)].append(
-                            Mention(
-                                text_tokens,
-                                len(tokens),
-                                len(tokens) + len(text_tokens),
+                while True:
+
+                    m = re.search(r"{([^:]+):EN=\".*?\" ", text)
+
+                    if m is None:
+                        tokens += m_tokenizer.tokenize(text, escape=False)
+                        break
+
+                    # add tokens seen so far
+                    tokens += m_tokenizer.tokenize(text[: m.start()])
+
+                    # look for the end of the chain
+                    open_count = 0
+                    chain_end = None
+                    for i, c in enumerate(text[m.start() + 1 :]):
+                        if c == "{":
+                            open_count += 1
+                        elif c == "}":
+                            if open_count == 0:
+                                chain_end = m.start() + 1 + i
+                                break
+                            open_count -= 1
+                    if chain_end is None:
+                        raise ValueError(f"Unbalanced braces found in {fpath}.")
+
+                    # recursively parse mention and update tokens
+                    # and chains
+                    subtokens, subchains = parse(text[m.end() : chain_end])
+                    for subchain_id, subchain in subchains.items():
+                        for submention in subchain:
+                            chains[subchain_id].append(
+                                submention.shifted(+len(tokens))
                             )
+
+                    # deal with current mention
+                    chain_id = m.group(1)
+                    chains[chain_id].append(
+                        Mention(
+                            subtokens, len(tokens), len(tokens) + len(subtokens)
                         )
+                    )
 
-                        tokens += text_tokens
+                    # move forward
+                    tokens += subtokens
+                    text = text[chain_end + 1 :]
 
                 return tokens, chains
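
The rewritten parse walks the string left to right instead of regex-splitting it: find the next mention opening, brace-match to locate where that mention closes (skipping nested openings), recurse into the mention body, shift the nested mentions by the number of tokens already consumed, and resume after the closing brace. Below is a self-contained sketch of the same strategy that runs outside tibert; whitespace str.split stands in for m_tokenizer.tokenize, plain (start, end) tuples stand in for Mention objects, and it assumes the #COLOR/#TOKENIZATION-TYPE headers have already been stripped as in the diff above. Both substitutions are illustrative, not tibert's API.

import re
from collections import defaultdict
from typing import Dict, List, Tuple

def parse(text: str) -> Tuple[List[str], Dict[str, List[Tuple[int, int]]]]:
    tokens: List[str] = []
    chains: Dict[str, List[Tuple[int, int]]] = defaultdict(list)
    while True:
        # next mention opening, e.g. {T109:EN="p PER"
        m = re.search(r"{([^:]+):EN=\".*?\" ", text)
        if m is None:
            tokens += text.split()
            return tokens, chains
        # tokens before the mention opening
        tokens += text[: m.start()].split()
        # brace-match to find where this mention closes,
        # skipping any nested mention openings
        open_count = 0
        chain_end = None
        for i, c in enumerate(text[m.start() + 1 :]):
            if c == "{":
                open_count += 1
            elif c == "}":
                if open_count == 0:
                    chain_end = m.start() + 1 + i
                    break
                open_count -= 1
        if chain_end is None:
            raise ValueError("Unbalanced braces.")
        # recurse into the mention body; nested mention offsets are
        # relative to that body, so shift them by the tokens seen so far
        subtokens, subchains = parse(text[m.end() : chain_end])
        for chain_id, mentions in subchains.items():
            chains[chain_id] += [
                (s + len(tokens), e + len(tokens)) for s, e in mentions
            ]
        chains[m.group(1)].append((len(tokens), len(tokens) + len(subtokens)))
        tokens += subtokens
        text = text[chain_end + 1 :]

tokens, chains = parse(
    'Hier , {T1:EN="p PER" le fils de {T2:EN="p PER" Jeanne}} dormait .'
)
print(tokens)        # ['Hier', ',', 'le', 'fils', 'de', 'Jeanne', 'dormait', '.']
print(dict(chains))  # {'T2': [(5, 6)], 'T1': [(2, 6)]}

The nested T2 mention is recorded at (3, 4) relative to its enclosing mention body and comes back shifted to (5, 6); that shift is what submention.shifted(+len(tokens)) does in the real code.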
