Skip to content
This repository was archived by the owner on Nov 8, 2022. It is now read-only.

Commit b118107

Browse files
authored
Merge pull request #338 from NervanaSystems/alon/master_local
Alon/master local
2 parents de2fda8 + 3776822 commit b118107

File tree

9 files changed

+107
-98
lines changed

9 files changed

+107
-98
lines changed

examples/cross_doc_coref/cross_doc_coref_sieves.py

Lines changed: 45 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@
1515
# ******************************************************************************
1616

1717
import logging
18+
from typing import List
1819

1920
from nlp_architect import LIBRARY_ROOT
21+
from nlp_architect.common.cdc.cluster import Clusters
22+
from nlp_architect.common.cdc.topics import Topics
2023
from nlp_architect.data.cdc_resources.relations.relation_types_enums import RelationType
2124
from nlp_architect.models.cross_doc_coref.cdc_config import EventConfig, EntityConfig
2225
from nlp_architect.models.cross_doc_coref.cdc_resource import CDCResources
@@ -25,7 +28,24 @@
2528
from nlp_architect.models.cross_doc_sieves import run_event_coref, run_entity_coref
2629

2730

28-
def run_example():
31+
def run_example(cdc_settings):
    """Run event and entity cross-document coreference on the ECB example mentions.

    Args:
        cdc_settings: settings object exposing ``event_config`` and
            ``entity_config`` (each with a ``run_evaluation`` flag). It is
            passed unchanged to ``run_event_coref`` / ``run_entity_coref``.

    Returns:
        A ``(event_clusters, entity_clusters)`` tuple. Each element is the
        result of the matching coreference run, or ``None`` when that run is
        disabled through its ``run_evaluation`` flag.
    """
    # Resolve the logger here: the module-level ``logger`` is only created
    # under the ``__main__`` guard, so relying on it raises NameError when
    # this function is called from an importing module.
    log = logging.getLogger(__name__)

    event_mentions = Topics(LIBRARY_ROOT + '/datasets/ecb/ecb_all_event_mentions.json')
    event_clusters = None
    if cdc_settings.event_config.run_evaluation:
        log.info('Running event coreference resolution')
        event_clusters = run_event_coref(event_mentions, cdc_settings)

    entity_mentions = Topics(LIBRARY_ROOT + '/datasets/ecb/ecb_all_entity_mentions.json')
    entity_clusters = None
    if cdc_settings.entity_config.run_evaluation:
        log.info('Running entity coreference resolution')
        entity_clusters = run_entity_coref(entity_mentions, cdc_settings)

    return event_clusters, entity_clusters
46+
47+
48+
def create_example_settings():
2949
event_config = EventConfig()
3050
event_config.sieves_order = [
3151
(SieveType.STRICT, RelationType.SAME_HEAD_LEMMA, 0.0),
@@ -34,11 +54,7 @@ def run_example():
3454
(SieveType.RELAX, RelationType.SAME_HEAD_LEMMA_RELAX, 0.5),
3555
]
3656

37-
event_config.gold_mentions_file = LIBRARY_ROOT + \
38-
'/datasets/ecb/ecb_all_event_mentions.json'
39-
4057
entity_config = EntityConfig()
41-
4258
entity_config.sieves_order = [
4359
(SieveType.STRICT, RelationType.SAME_HEAD_LEMMA, 0.0),
4460
(SieveType.VERY_RELAX, RelationType.WIKIPEDIA_REDIRECT_LINK, 0.1),
@@ -47,44 +63,42 @@ def run_example():
4763
(SieveType.VERY_RELAX, RelationType.REFERENT_DICT, 0.5)
4864
]
4965

50-
entity_config.gold_mentions_file = LIBRARY_ROOT + \
51-
'/datasets/ecb/ecb_all_entity_mentions.json'
52-
5366
# CDCResources holds default attribute values that might need to be changed
5467
# (this example uses the defaults). Use it to configure attributes such as
5568
# resource file locations, the output directory, resource init methods and others.
5669
# Check the class to see whether any attributes require changes for your set-up.
5770
resource_location = CDCResources()
58-
resources = CDCSettings(resource_location, event_config, entity_config)
71+
return CDCSettings(resource_location, event_config, entity_config)
5972

60-
event_clusters = None
61-
if event_config.run_evaluation:
62-
logger.info('Running event coreference resolution')
63-
event_clusters = run_event_coref(resources)
6473

65-
entity_clusters = None
66-
if entity_config.run_evaluation:
67-
logger.info('Running entity coreference resolution')
68-
entity_clusters = run_entity_coref(resources)
74+
def print_results(clusters: List[Clusters], type: str):
    """Print the mentions of every cluster, grouped by topic.

    Args:
        clusters: per-topic ``Clusters`` objects. May be ``None`` or empty
            (``run_example`` returns ``None`` for a disabled run); in that
            case only the header line is printed.
        type: label used in the header, e.g. ``'Event'`` or ``'Entity'``.
            The name shadows the builtin but is kept so existing callers
            that pass it by keyword keep working.
    """
    print('-=' + type + ' Clusters=-')
    # ``clusters or []`` avoids a TypeError when the coreference run was skipped.
    for topic_cluster in clusters or []:
        print('\n\tTopic=' + str(topic_cluster.topic_id))
        for cluster in topic_cluster.clusters_list:
            cluster_mentions = [
                {'id': mention.mention_id, 'text': mention.tokens_str}
                for mention in cluster.mentions
            ]
            print('\t\tCluster(' + str(cluster.coref_chain) + ') Mentions='
                  + str(cluster_mentions))
6988

70-
print('-=Cross Document Coref Results=-')
71-
print('-=Event Clusters Mentions=-')
72-
for event_cluster in event_clusters.clusters_list:
73-
print(event_cluster.coref_chain)
74-
for event_mention in event_cluster.mentions:
75-
print(event_mention.mention_id)
76-
print(event_mention.tokens_str)
7789

78-
print('-=Entity Clusters Mentions=-')
79-
for entity_cluster in entity_clusters.clusters_list:
80-
print(entity_cluster.coref_chain)
81-
for entity_mention in entity_cluster.mentions:
82-
print(entity_mention.mention_id)
83-
print(entity_mention.tokens_str)
90+
def run_cdc_pipeline():
    """Build the example settings, run the coreference example and print the results.

    ``run_example`` returns ``None`` for a coreference type whose
    ``run_evaluation`` flag is off, so each result is printed only when it
    was actually produced.
    """
    cdc_settings = create_example_settings()
    event_clusters, entity_clusters = run_example(cdc_settings)

    print('-=Cross Document Coref Results=-')
    if event_clusters is not None:
        print_results(event_clusters, 'Event')
    print('################################')
    if entity_clusters is not None:
        print_results(entity_clusters, 'Entity')
8498

8599

86100
if __name__ == '__main__':
87101
logging.basicConfig(level=logging.INFO)
88102
logger = logging.getLogger(__name__)
89103

90-
run_example()
104+
run_cdc_pipeline()

nlp_architect/common/cdc/cluster.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,15 @@ def get_cluster_id(self) -> str:
6565
class Clusters(object):
6666
cluster_coref_chain = 1000
6767

68-
def __init__(self, mentions: List[MentionData] = None) -> None:
68+
def __init__(self, topic_id: str, mentions: List[MentionData] = None) -> None:
6969
"""
7070
7171
Args:
7272
mentions: ``list[MentionData]``, required
7373
The initial mentions to create the clusters from
7474
"""
7575
self.clusters_list = []
76+
self.topic_id = topic_id
7677
self.set_initial_clusters(mentions)
7778

7879
def set_initial_clusters(self, mentions: List[MentionData]) -> None:

nlp_architect/common/cdc/topics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ def __init__(self, mentions_file_path: str) -> None:
3737
Args:
3838
mentions_file_path: this topic mentions json file
3939
"""
40-
self.topics_list = self.load_gold_mentions(mentions_file_path)
40+
self.topics_list = self.load_gold_mentions_from_file(mentions_file_path)
4141

42-
def load_gold_mentions(self, mentions_file_path: str) -> List[Topic]:
42+
def load_gold_mentions_from_file(self, mentions_file_path: str) -> List[Topic]:
4343
start_data_load = time.time()
4444
logger.info('Loading mentions from-%s', mentions_file_path)
4545
mentions = load_json_file(mentions_file_path)

nlp_architect/data/cdc_resources/wikipedia/wiki_online.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# ******************************************************************************
1616

1717
import os
18+
import logging
1819

1920
from nlp_architect.data.cdc_resources.data_types.wiki.wikipedia_page import WikipediaPage
2021
from nlp_architect.data.cdc_resources.data_types.wiki.wikipedia_page_extracted_relations import \
@@ -27,6 +28,8 @@
2728
DISAMBIGUATE_PAGE = ['wikimedia disambiguation page', 'wikipedia disambiguation page']
2829
NAME_DESCRIPTIONS = ['given name', 'first name', 'family name']
2930

31+
logger = logging.getLogger(__name__)
32+
3033

3134
class WikiOnline(object):
3235
def __init__(self):
@@ -52,7 +55,7 @@ def get_pages(self, phrase):
5255
full_page = self.get_wiki_page_with_items(phrase, page_result)
5356
ret_pages.add(WikipediaSearchPageResult(appr, full_page))
5457
except Exception as e:
55-
print(e)
58+
logger.error(e)
5659

5760
self.cache[phrase] = ret_pages
5861
return ret_pages
@@ -73,7 +76,7 @@ def get_wiki_page_with_items(self, phrase, page):
7376

7477
ret_page = WikipediaPage(phrase, None, page_title, None, 0, pageid, description, relations)
7578

76-
print('Page:' + str(ret_page) + ". Extracted successfully")
79+
logger.debug('Page:' + str(ret_page) + ". Extracted successfully")
7780

7881
return ret_page
7982

nlp_architect/models/cross_doc_coref/cdc_config.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
# ******************************************************************************
1616
from typing import List, Tuple
1717

18-
from nlp_architect import LIBRARY_ROOT
1918
from nlp_architect.data.cdc_resources.relations.relation_types_enums import RelationType
2019
from nlp_architect.models.cross_doc_coref.system.sieves.sieves import SieveType
2120

@@ -26,20 +25,19 @@ def __init__(self):
2625

2726
self.__sieves_order = None
2827
self.__run_evaluation = False
29-
self.__gold_mentions_file = None
3028

3129
@property
3230
def sieves_order(self):
3331
"""
3432
Sieve definition and Sieve running order
3533
36-
Tuple[SieveType, RelationType, Threshold(float)] - define sieves to run, were
34+
Tuple[SieveType, RelationType, Threshold(float)] - defines the sieves to run, where
3735
38-
Strict- Merge clusters only in case all mentions has current relation between them,
39-
Relax- Merge clusters in case (matched mentions) / len(cluster_1.mentions)) >= thresh,
40-
Very_Relax- Merge clusters in case (matched mentions) / (all possible pairs) >= thresh
36+
Strict- Merge clusters only if all mentions have the current relation between them,
37+
Relax- Merge clusters in case (matched mentions) / len(cluster_1.mentions) >= thresh,
38+
Very_Relax- Merge clusters in case (matched mentions) / (all possible pairs) >= thresh
4139
42-
RelationType represent the type of sieve to run.
40+
RelationType represents the type of sieve to run.
4341
4442
"""
4543
return self.__sieves_order
@@ -57,15 +55,6 @@ def run_evaluation(self):
5755
def run_evaluation(self, run_evaluation: bool):
5856
self.__run_evaluation = run_evaluation
5957

60-
@property
61-
def gold_mentions_file(self):
62-
"""Mentions file to run against"""
63-
return self.__gold_mentions_file
64-
65-
@gold_mentions_file.setter
66-
def gold_mentions_file(self, gold_file):
67-
self.__gold_mentions_file = gold_file
68-
6958

7059
class EventConfig(CDCConfig):
7160
def __init__(self):
@@ -91,9 +80,6 @@ def __init__(self):
9180
(SieveType.STRICT, RelationType.WORDNET_DERIVATIONALLY, 0.0)
9281
]
9382

94-
self.gold_mentions_file = LIBRARY_ROOT + \
95-
'/datasets/ecb/ecb_all_event_mentions.json'
96-
9783

9884
class EntityConfig(CDCConfig):
9985
def __init__(self):
@@ -118,6 +104,3 @@ def __init__(self):
118104
(SieveType.STRICT, RelationType.WORDNET_SAME_SYNSET_ENTITY, 0.0),
119105
(SieveType.VERY_RELAX, RelationType.REFERENT_DICT, 0.5)
120106
]
121-
122-
self.gold_mentions_file = LIBRARY_ROOT + \
123-
'/datasets/ecb/ecb_all_entity_mentions.json'

nlp_architect/models/cross_doc_coref/system/cdc_settings.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import logging
1818

19-
from nlp_architect.common.cdc.topics import Topics
2019
from nlp_architect.data.cdc_resources.relations.computed_relation_extraction import \
2120
ComputedRelationExtraction
2221
from nlp_architect.data.cdc_resources.relations.referent_dict_relation_extraction import \
@@ -44,21 +43,12 @@ def __init__(self, resources, event_coref_config, entity_coref_config):
4443
self.context2vec_model = None
4544
self.wordnet = None
4645
self.within_doc = None
47-
self.events_topics = None
48-
self.entity_topics = None
4946
self.event_config = event_coref_config
5047
self.entity_config = entity_coref_config
5148
self.cdc_resources = resources
5249

5350
self.load_modules()
5451

55-
if event_coref_config.run_evaluation:
56-
self.events_topics = Topics(event_coref_config.gold_mentions_file)
57-
if entity_coref_config.run_evaluation:
58-
self.entity_topics = Topics(entity_coref_config.gold_mentions_file)
59-
if not self.events_topics and not self.entity_topics:
60-
raise Exception('No entity or events Gold topics loaded!')
61-
6252
def load_modules(self):
6353
relations = set()
6454
for sieve in self.event_config.sieves_order:

nlp_architect/models/cross_doc_coref/system/sieves/run_sieve_system.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import time
1919

2020
from nlp_architect.common.cdc.cluster import Clusters
21+
from nlp_architect.common.cdc.topics import Topic
22+
from nlp_architect.models.cross_doc_coref.system.cdc_settings import CDCSettings
2123
from nlp_architect.models.cross_doc_coref.system.sieves.sieves import get_sieve
2224

2325
logger = logging.getLogger(__name__)
@@ -29,7 +31,7 @@ def __init__(self, topic):
2931
self.results_dict = dict()
3032
self.results_ordered = []
3133
logger.info('loading topic %s, total mentions: %d', topic.topic_id, len(topic.mentions))
32-
self.clusters = Clusters(topic.mentions)
34+
self.clusters = Clusters(topic.topic_id, topic.mentions)
3335

3436
@staticmethod
3537
def set_sieves_from_config(config, get_rel_extraction):
@@ -90,3 +92,12 @@ def __init__(self, topic, resources):
9092
super(RunSystemsEvent, self).__init__(topic)
9193
self.sieves = self.set_sieves_from_config(resources.event_config,
9294
resources.get_module_from_relation)
95+
96+
97+
def get_run_system(topic: Topic, resource: CDCSettings, eval_type: str):
    """Create the sieve run-system matching ``eval_type``.

    Args:
        topic: topic handed to the run-system constructor.
        resource: settings handed to the run-system constructor.
        eval_type: ``'entity'`` or ``'event'`` (compared case-insensitively).

    Returns:
        A ``RunSystemsEntity`` or ``RunSystemsEvent`` instance.

    Raises:
        AttributeError: if ``eval_type`` is neither ``'entity'`` nor ``'event'``.
    """
    systems_by_type = {
        'entity': RunSystemsEntity,
        'event': RunSystemsEvent,
    }
    normalized_type = eval_type.lower()
    if normalized_type not in systems_by_type:
        raise AttributeError(eval_type + ' Not supported!')
    return systems_by_type[normalized_type](topic, resource)

0 commit comments

Comments
 (0)