@@ -252,11 +252,13 @@ def __init__(
252252 entity_label_types : Union [str , Sequence [str ], dict [str , Optional [set [str ]]]],
253253 entity_pair_labels : Optional [set [tuple [str , str ]]] = None ,
254254 entity_threshold : Optional [float ] = None ,
255+ max_allowed_tokens_between_entities : Optional [int ] = 20 ,
256+ max_surrounding_context_length : Optional [int ] = 10 ,
255257 cross_augmentation : bool = True ,
256258 encoding_strategy : EncodingStrategy = TypedEntityMarker (),
257259 zero_tag_value : str = "O" ,
258260 allow_unk_tag : bool = True ,
259- ** classifierargs ,
261+ ** classifierargs : Any ,
260262 ) -> None :
261263 """Initializes a `RelationClassifier`.
262264
@@ -267,6 +269,8 @@ def __init__(
267269 entity_label_types: A label type or sequence of label types of the required relation entities. You can also specify a label filter in a dictionary with the label type as key and the valid entity labels as values in a set. E.g. to use only 'PER' and 'ORG' labels from a NER-tagger: `{'ner': {'PER', 'ORG'}}`. To use all labels from 'ner', pass 'ner'.
268270 entity_pair_labels: A set of valid relation entity pair combinations, used as relation candidates. Specify valid entity pairs in a set of tuples of labels (<HEAD>, <TAIL>). E.g. for the `born_in` relation, only relations from 'PER' to 'LOC' make sense. Here, relations from 'PER' to 'PER' are not meaningful, so it is advised to specify the `entity_pair_labels` as `{('PER', 'ORG')}`. This setting may help to reduce the number of relation candidates. Leaving this parameter as `None` (default) disables the relation-candidate-filter, i.e. the model classifies the relation for each entity pair in the cross product of *all* entity pairs (inefficient).
269271 entity_threshold: Only pre-labelled entities above this threshold are taken into account by the model.
272+ max_allowed_tokens_between_entities: The maximum allowed number of allowed tokens between entities. All other entity pairs are filtered from consideration. If `None`, the filter will be disabled.
273+ max_surrounding_context_length: The maximum length of context around entity pairs that will be considered. The context, in between the entity pairs will always be included. If `None`, the filter will be disabled.
270274 cross_augmentation: If `True`, use cross augmentation to transform `Sentence`s into `EncodedSentenece`s. When cross augmentation is enabled, the transformation functions, e.g. `transform_corpus`, generate an encoded sentence for each entity pair in the cross product of all entities in the original sentence. When disabling cross augmentation, the transform functions only generate encoded sentences for each gold relation annotation in the original sentence.
271275 encoding_strategy: An instance of a class conforming the :class:`EncodingStrategy` protocol
272276 zero_tag_value: The label to use for out-of-class relations
@@ -302,6 +306,8 @@ def __init__(
302306 self .entity_pair_labels = entity_pair_labels
303307
304308 self .entity_threshold = entity_threshold
309+ self .max_allowed_tokens_between_entities = max_allowed_tokens_between_entities
310+ self .max_surrounding_context_length = max_surrounding_context_length
305311 self .cross_augmentation = cross_augmentation
306312 self .encoding_strategy = encoding_strategy
307313
@@ -393,12 +399,41 @@ def _entity_pair_permutations(
393399
394400 yield head , tail , gold_label
395401
402+ @staticmethod
403+ def _truncate_context_around_entities (
404+ encoded_sentence_tokens : list [str ],
405+ head_idx : int ,
406+ tail_idx : int ,
407+ context_length : int ,
408+ ) -> list [str ]:
409+ """Truncates the encoded sentence to include the head and tail entity and their surrounding context.
410+
411+ The context, in between the entity pairs will always be included.
412+
413+ Args:
414+ encoded_sentence_tokens: The list of tokens corresponding to the encoded sentence.
415+ head_idx: The index of the head entity in the token list.
416+ tail_idx: The index of the tail entity in the token list.
417+ context_length: The maximum number of tokens to include as surrounding context around the head and tail entities.
418+
419+ Returns:
420+ The tokens of the truncated sentence.
421+ """
422+ begin_slice : int = min (head_idx , tail_idx )
423+ end_slice : int = max (head_idx , tail_idx )
424+
425+ # Preserve context around the entities. Always include their in-between context.
426+ begin_slice = max (begin_slice - context_length , 0 )
427+ end_slice = min (end_slice + context_length + 1 , len (encoded_sentence_tokens ))
428+
429+ return encoded_sentence_tokens [begin_slice :end_slice ]
430+
396431 def _encode_sentence (
397432 self ,
398433 head : _Entity ,
399434 tail : _Entity ,
400435 gold_label : Optional [str ] = None ,
401- ) -> EncodedSentence :
436+ ) -> Optional [ EncodedSentence ] :
402437 """Returns a new Sentence object with masked/marked head and tail spans according to the encoding strategy.
403438
404439 If provided, the encoded sentence also has the corresponding gold label annotation from :attr:`~label_type`.
@@ -414,6 +449,12 @@ def _encode_sentence(
414449 original_sentence : Sentence = head .span .sentence
415450 assert original_sentence is tail .span .sentence , "The head and tail need to come from the same sentence."
416451
452+ # Sanity check: Do not create a labeled span if one entity contains the other
453+ if head .span [0 ].idx <= tail .span [0 ].idx and head .span [- 1 ].idx >= tail .span [- 1 ].idx :
454+ return None
455+ if head .span [0 ].idx >= tail .span [0 ].idx and head .span [- 1 ].idx <= tail .span [- 1 ].idx :
456+ return None
457+
417458 # Pre-compute non-leading head and tail tokens for entity masking
418459 non_leading_head_tokens : list [Token ] = head .span .tokens [1 :]
419460 non_leading_tail_tokens : list [Token ] = tail .span .tokens [1 :]
@@ -422,11 +463,15 @@ def _encode_sentence(
422463 # since there may be multiple occurrences of the same entity mentioned in the sentence.
423464 # Therefore, we use the span's position in the sentence.
424465 encoded_sentence_tokens : list [str ] = []
466+ head_idx : Optional [int ] = None
467+ tail_idx : Optional [int ] = None
425468 for token in original_sentence :
426469 if token is head .span [0 ]:
470+ head_idx = len (encoded_sentence_tokens )
427471 encoded_sentence_tokens .append (self .encoding_strategy .encode_head (head .span , head .label ))
428472
429473 elif token is tail .span [0 ]:
474+ tail_idx = len (encoded_sentence_tokens )
430475 encoded_sentence_tokens .append (self .encoding_strategy .encode_tail (tail .span , tail .label ))
431476
432477 elif all (
@@ -435,6 +480,27 @@ def _encode_sentence(
435480 ):
436481 encoded_sentence_tokens .append (token .text )
437482
483+ msg : str
484+ if head_idx is None :
485+ msg = f"The head entity ({ head !r} ) is not located inside the original sentence ({ original_sentence !r} )."
486+ raise AssertionError (msg )
487+ if tail_idx is None :
488+ msg = f"The tail entity ({ tail !r} ) is not located inside the original sentence ({ original_sentence !r} )."
489+ raise AssertionError (msg )
490+
491+ # Filter cases in which the distance between the two entities is too large
492+ if (
493+ self .max_allowed_tokens_between_entities is not None
494+ and abs (head_idx - tail_idx ) > self .max_allowed_tokens_between_entities
495+ ):
496+ return None
497+
498+ # Remove excess tokens left and right of entity pair to make encoded sentence shorter
499+ if self .max_surrounding_context_length is not None :
500+ encoded_sentence_tokens = self ._truncate_context_around_entities (
501+ encoded_sentence_tokens , head_idx , tail_idx , self .max_surrounding_context_length
502+ )
503+
438504 # Create masked sentence
439505 encoded_sentence : EncodedSentence = EncodedSentence (
440506 " " .join (encoded_sentence_tokens ), use_tokenizer = SpaceTokenizer ()
@@ -445,6 +511,7 @@ def _encode_sentence(
445511 # Using the sentence label instead of annotating a separate `Relation` object is easier to manage since,
446512 # during prediction, the forward pass does not need any knowledge about the entities in the sentence.
447513 encoded_sentence .add_label (typename = self .label_type , value = gold_label , score = 1.0 )
514+
448515 encoded_sentence .copy_context_from_sentence (original_sentence )
449516 return encoded_sentence
450517
@@ -469,13 +536,15 @@ def _encode_sentence_for_inference(
469536 Returns: Encoded sentences annotated with their gold relation and the corresponding relation in the original sentence
470537 """
471538 for head , tail , gold_label in self ._entity_pair_permutations (sentence ):
472- masked_sentence : EncodedSentence = self ._encode_sentence (
539+ masked_sentence : Optional [ EncodedSentence ] = self ._encode_sentence (
473540 head = head ,
474541 tail = tail ,
475542 gold_label = gold_label if gold_label is not None else self .zero_tag_value ,
476543 )
477544 original_relation : Relation = Relation (first = head .span , second = tail .span )
478- yield masked_sentence , original_relation
545+
546+ if masked_sentence is not None :
547+ yield masked_sentence , original_relation
479548
480549 def _encode_sentence_for_training (self , sentence : Sentence ) -> Iterator [EncodedSentence ]:
481550 """Create Encoded Sentences and Relation pairs for Training.
@@ -492,13 +561,14 @@ def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedS
492561 else :
493562 continue # Skip generated data points that do not express an originally annotated relation
494563
495- masked_sentence : EncodedSentence = self ._encode_sentence (
564+ masked_sentence : Optional [ EncodedSentence ] = self ._encode_sentence (
496565 head = head ,
497566 tail = tail ,
498567 gold_label = gold_label ,
499568 )
500569
501- yield masked_sentence
570+ if masked_sentence is not None :
571+ yield masked_sentence
502572
503573 def transform_sentence (self , sentences : Union [Sentence , list [Sentence ]]) -> list [EncodedSentence ]:
504574 """Transforms sentences into encoded sentences specific to the `RelationClassifier`.
@@ -702,6 +772,8 @@ def _get_state_dict(self) -> dict[str, Any]:
702772 "entity_label_types" : self .entity_label_types ,
703773 "entity_pair_labels" : self .entity_pair_labels ,
704774 "entity_threshold" : self .entity_threshold ,
775+ "max_allowed_tokens_between_entities" : self .max_allowed_tokens_between_entities ,
776+ "max_surrounding_context_length" : self .max_surrounding_context_length ,
705777 "cross_augmentation" : self .cross_augmentation ,
706778 "encoding_strategy" : self .encoding_strategy ,
707779 "zero_tag_value" : self .zero_tag_value ,
@@ -719,6 +791,8 @@ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs):
719791 entity_label_types = state ["entity_label_types" ],
720792 entity_pair_labels = state ["entity_pair_labels" ],
721793 entity_threshold = state ["entity_threshold" ],
794+ max_allowed_tokens_between_entities = state .get ("max_allowed_tokens_between_entities" ),
795+ max_surrounding_context_length = state .get ("max_surrounding_context_length" ),
722796 cross_augmentation = state ["cross_augmentation" ],
723797 encoding_strategy = state ["encoding_strategy" ],
724798 zero_tag_value = state ["zero_tag_value" ],
0 commit comments