@@ -332,9 +332,9 @@ def torch_call(self, features) -> Union[dict, BatchEncoding]:
             # same length yet.
             return_tensors=None,
         )
-        self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = (
-            warning_state
-        )
+        self.tokenizer.deprecation_warnings[
+            "Asking-to-pad-a-fast-tokenizer"
+        ] = warning_state
 
         # keep encoding info
         batch._encodings = [f.encodings[0] for f in features]
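
The first hunk only reflows the statement that restores the tokenizer's deprecation-warning state after padding. For context, a minimal sketch of the save/suppress/restore pattern it belongs to, assuming a Hugging Face fast tokenizer; the `features` list here is illustrative:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
features = [{"input_ids": [101, 2023, 102]}, {"input_ids": [101, 2023, 2003, 102]}]

# remember whether the warning was already emitted, then mark it as
# emitted so that pad() stays silent for this call
warning_state = tokenizer.deprecation_warnings.get(
    "Asking-to-pad-a-fast-tokenizer", False
)
tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

# pad without converting to tensors (the features may not all have the
# same length yet)
batch = tokenizer.pad(features, return_tensors=None)

# restore the previous state so later callers still get the warning
tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state
```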
@@ -532,50 +532,64 @@ def from_sacr_dir(
 
         for fpath in tqdm(paths):
             with open(fpath, **kwargs) as f:
-                text = f.read().replace("\n", " ")
+                text = f.read()
+                text = re.sub(r"#COLOR:.*\n", "", text)
+                text = re.sub(r"#TOKENIZATION-TYPE:.*\n", "", text)
+                text = text.replace("\n", " ")
 
             def parse(text: str) -> Tuple[List[str], Dict[str, List[Mention]]]:
-                splitted = re.split(r"({T[0-9]+:EN=\".*?\" [^{]*})", text)
-
-                if len(splitted) == 1:
-                    return m_tokenizer.tokenize(text, escape=False), {}
-
+                # SACR format example:
+                # {T109:EN="p PER" Le nouveau-né} s’agite dans {T109:EN="p PER" son} berceau.
+                # This format can be nested.
                 tokens: List[str] = []
                 # { id => chain }
                 chains: Dict[str, List[Mention]] = defaultdict(list)
 
-                # SACR format example:
-                # {T109:EN="p PER" Le nouveau-né} s’agite dans {T109:EN="p PER" son} berceau.
-                #
-                # Split the text using a pattern that matches text
-                # between braces. The text variable holds,
-                # alternately, either text between braces or
-                # regular text.
-                for i, text in enumerate(splitted):
-                    # regular text
-                    if i % 2 == 0:
-                        text_tokens = m_tokenizer.tokenize(text, escape=False)
-                        tokens += text_tokens
-                    # text inside braces represents a coreference mention
-                    else:
-                        text_match = re.search(r"{T([0-9]+):EN=\".*?\" (.*)}", text)
-                        assert not text_match is None
-                        text_tokens, subchains = parse(text_match.group(2))
-
-                        for chain_key, mentions in subchains.items():
-                            chains[chain_key] += [
-                                m.shifted(len(tokens)) for m in mentions
-                            ]
-
-                        chains[text_match.group(1)].append(
-                            Mention(
-                                text_tokens,
-                                len(tokens),
-                                len(tokens) + len(text_tokens),
+                while True:
+
+                    m = re.search(r"{([^:]+):EN=\".*?\" ", text)
+
+                    if m is None:
+                        tokens += m_tokenizer.tokenize(text, escape=False)
+                        break
+
+                    # add the plain text tokens before the mention
+                    tokens += m_tokenizer.tokenize(text[: m.start()], escape=False)
+
+                    # look for the end of the mention (its closing brace)
+                    open_count = 0
+                    chain_end = None
+                    for i, c in enumerate(text[m.start() + 1 :]):
+                        if c == "{":
+                            open_count += 1
+                        elif c == "}":
+                            if open_count == 0:
+                                chain_end = m.start() + 1 + i
+                                break
+                            open_count -= 1
+                    if chain_end is None:
+                        raise ValueError(f"Unbalanced braces found in {fpath}.")
+
+                    # recursively parse the mention and update tokens
+                    # and chains
+                    subtokens, subchains = parse(text[m.end() : chain_end])
+                    for subchain_id, subchain in subchains.items():
+                        for submention in subchain:
+                            chains[subchain_id].append(
+                                submention.shifted(+len(tokens))
                             )
+
+                    # deal with the current mention
+                    chain_id = m.group(1)
+                    chains[chain_id].append(
+                        Mention(
+                            subtokens, len(tokens), len(tokens) + len(subtokens)
                         )
+                    )
 
-                        tokens += text_tokens
+                    # move forward
+                    tokens += subtokens
+                    text = text[chain_end + 1 :]
 
                 return tokens, chains
 
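
The rewritten `parse` replaces the old one-shot regex split, which could not handle nested annotations, with an explicit brace-matching scan plus recursion. Below is a self-contained sketch of the same technique, using plain whitespace `str.split` in place of the project's `m_tokenizer` and a hypothetical simplified `Mention`:

```python
import re
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple

@dataclass
class Mention:
    # simplified stand-in for the project's Mention class (hypothetical)
    tokens: List[str]
    start: int
    end: int

    def shifted(self, shift: int) -> "Mention":
        return Mention(self.tokens, self.start + shift, self.end + shift)

def parse(text: str) -> Tuple[List[str], Dict[str, List[Mention]]]:
    tokens: List[str] = []
    chains: Dict[str, List[Mention]] = defaultdict(list)
    while True:
        m = re.search(r"{([^:]+):EN=\".*?\" ", text)
        if m is None:
            tokens += text.split()
            return tokens, dict(chains)
        tokens += text[: m.start()].split()
        # scan for the closing brace, counting nested opening braces
        open_count, chain_end = 0, None
        for i, c in enumerate(text[m.start() + 1 :]):
            if c == "{":
                open_count += 1
            elif c == "}":
                if open_count == 0:
                    chain_end = m.start() + 1 + i
                    break
                open_count -= 1
        if chain_end is None:
            raise ValueError("Unbalanced braces.")
        # recurse into the mention body, shifting nested mentions by the
        # number of tokens already consumed
        subtokens, subchains = parse(text[m.end() : chain_end])
        for chain_id, subchain in subchains.items():
            chains[chain_id] += [sm.shifted(len(tokens)) for sm in subchain]
        chains[m.group(1)].append(
            Mention(subtokens, len(tokens), len(tokens) + len(subtokens))
        )
        tokens += subtokens
        text = text[chain_end + 1 :]

sacr = '{T109:EN="p PER" Le nouveau-né} s’agite dans {T109:EN="p PER" son} berceau.'
tokens, chains = parse(sacr)
# chains["T109"] now holds two mentions: "Le nouveau-né" and "son"
```

On the example from the comment in the diff, this yields one chain `T109` with a mention spanning tokens 0–2 and another spanning tokens 4–5; nesting works because the scan only closes a mention once every inner brace pair has been balanced.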