Skip to content

Commit 157ae6f

Browse files
thomaslim6793 and bgyori
authored and committed
Updated the indra_bert API to use the JSON INDRA statement output directly, instead of the more verbose raw output that required parsing and converting into INDRA statement objects.
1 parent 67cf028 commit 157ae6f

File tree

3 files changed

+48
-90
lines changed

3 files changed

+48
-90
lines changed

indra/resources/default_belief_probs.json

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@
3636
"acsn": 0.01,
3737
"semrep": 0.05,
3838
"wormbase": 0.01,
39-
"indra_bert": 0.05
39+
"indra_bert": 0.05,
40+
"indra_gpt": 0.05
4041
},
4142
"rand": {
4243
"eidos": 0.3,
@@ -75,6 +76,7 @@
7576
"acsn": 0.1,
7677
"semrep": 0.3,
7778
"wormbase": 0.1,
78-
"indra_bert": 0.3
79+
"indra_bert": 0.3,
80+
"indra_gpt": 0.3
7981
}
8082
}

indra/sources/indra_bert/api.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@ def create_extractor(
2020
ner_model_path="thomaslim6793/indra_bert_ner_agent_detection",
2121
stmt_model_path="thomaslim6793/indra_bert_indra_stmt_classifier",
2222
role_model_path="thomaslim6793/indra_bert_indra_stmt_agents_role_assigner",
23+
mutations_model_path="thomaslim6793/indra_bert_agent_mutation_detection",
2324
stmt_conf_threshold=0.95
2425
):
2526
try:
2627
ise = IndraStructuredExtractor(
2728
ner_model_path=ner_model_path,
2829
stmt_model_path=stmt_model_path,
29-
role_model_path=role_model_path,
30+
role_model_path=role_model_path,
31+
mutations_model_path=mutations_model_path,
3032
stmt_conf_threshold=stmt_conf_threshold
3133
)
3234
except Exception as e:
@@ -36,33 +38,38 @@ def create_extractor(
3638
ner_model_path="thomaslim6793/indra_bert_ner_agent_detection",
3739
stmt_model_path="thomaslim6793/indra_bert_indra_stmt_classifier",
3840
role_model_path="thomaslim6793/indra_bert_indra_stmt_agents_role_assigner",
41+
mutations_model_path="thomaslim6793/indra_bert_agent_mutation_detection",
3942
stmt_conf_threshold=stmt_conf_threshold
4043
)
4144
logger.info(f"Loaded ner_model from: {ise.ner_model_local_path}")
4245
logger.info(f"Loaded stmt_model from: {ise.stmt_model_local_path}")
4346
logger.info(f"Loaded role_model from: {ise.role_model_local_path}")
47+
logger.info(f"Loaded mutations_model from: {ise.mutations_model_local_path}")
4448
return ise
4549

4650
def process_text(text,
4751
ner_model_path="thomaslim6793/indra_bert_ner_agent_detection",
4852
stmt_model_path="thomaslim6793/indra_bert_indra_stmt_classifier",
4953
role_model_path="thomaslim6793/indra_bert_indra_stmt_agents_role_assigner",
54+
mutations_model_path="thomaslim6793/indra_bert_agent_mutation_detection",
5055
stmt_conf_threshold=0.95,
5156
grounder=None):
5257
ise = create_extractor(
5358
ner_model_path=ner_model_path,
5459
stmt_model_path=stmt_model_path,
55-
role_model_path=role_model_path,
60+
role_model_path=role_model_path,
61+
mutations_model_path=mutations_model_path,
5662
stmt_conf_threshold=stmt_conf_threshold
5763
)
58-
res = ise.extract_structured_statements_batch(text)
64+
res = ise.get_json_indra_stmts(text)
5965
ip = IndraBertProcessor(res, grounder=grounder)
6066
return ip, ise
6167

6268
def process_texts(texts,
6369
ner_model_path="thomaslim6793/indra_bert_ner_agent_detection",
6470
stmt_model_path="thomaslim6793/indra_bert_indra_stmt_classifier",
6571
role_model_path="thomaslim6793/indra_bert_indra_stmt_agents_role_assigner",
72+
mutations_model_path="thomaslim6793/indra_bert_agent_mutation_detection",
6673
stmt_conf_threshold=0.95,
6774
grounder=None):
6875

@@ -72,13 +79,14 @@ def process_texts(texts,
7279
ise = create_extractor(
7380
ner_model_path=ner_model_path,
7481
stmt_model_path=stmt_model_path,
75-
role_model_path=role_model_path,
82+
role_model_path=role_model_path,
83+
mutations_model_path=mutations_model_path,
7684
stmt_conf_threshold=stmt_conf_threshold
7785
)
7886

7987
ips = []
8088
for text in tqdm(texts, desc="Processing texts"):
81-
res = ise.extract_structured_statements_batch(text)
89+
res = ise.get_json_indra_stmts(text)
8290
ip = IndraBertProcessor(res, grounder=grounder)
8391
ips.append(ip)
8492
return ips, ise

indra/sources/indra_bert/processor.py

Lines changed: 31 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from indra.statements import *
2+
from indra.statements.io import stmt_from_json
23
from indra.ontology.standardize import standardize_agent_name
34

45
import re
@@ -14,92 +15,39 @@ def __init__(self, data, grounder=None):
1415
self.grounder = grounder if grounder else default_grounder_wrapper
1516
self.extract_statements()
1617

17-
def get_agent(self, agent_info, context=None):
18-
name = agent_info['text']
19-
db_refs = self.grounder(name, context)
20-
db_refs['TEXT'] = name
21-
agent = Agent(name, db_refs=db_refs)
22-
standardize_agent_name(agent, standardize_refs=True)
23-
return agent
2418

2519
def extract_statement(self, entry):
26-
stmt_type = entry['stmt_pred']['label']
27-
roles = entry['role_pred']['roles']
28-
text = entry['original_text']
29-
30-
agents_by_role = {}
31-
raw_texts = {}
32-
coords = {}
33-
for agent_info in roles:
34-
role = agent_info['role']
35-
agents_by_role[role] = self.get_agent(agent_info, text)
36-
raw_texts[role] = agent_info['text']
37-
coords[role] = ([agent_info['start'], agent_info['end']])
38-
39-
evidence = Evidence(
40-
source_api=self.source_api,
41-
text=text,
42-
)
43-
44-
stmt_class = get_statement_by_name(stmt_type)
45-
if issubclass(stmt_class, Complex):
46-
if len(agents_by_role) < 2:
47-
raise ValueError("Expected at least two roles: 'members'",
48-
f" but got {agents_by_role.keys()}")
49-
for role, _ in agents_by_role.items():
50-
if not re.match(r'members\.\d+', role):
51-
raise ValueError(f"Unexpected role '{role}' for members")
52-
53-
members = [agent for role, agent in agents_by_role.items()]
54-
raw_texts = [raw_text for role, raw_text in raw_texts.items()]
55-
coords = [coord for role, coord in coords.items()]
56-
annotations = {
57-
'agents': {
58-
'raw_text': raw_texts,
59-
'coords': coords
60-
}
61-
}
62-
evidence.annotations = annotations
63-
stmt = Complex(members, evidence=[evidence])
64-
return stmt
65-
elif issubclass(stmt_class, (RegulateAmount, RegulateActivity)):
66-
if agents_by_role.keys() != {'subj', 'obj'} or len(agents_by_role) != 2:
67-
raise ValueError("Expected exactly two roles: 'subj' and 'obj'",
68-
f" but got {agents_by_role.keys()}")
69-
70-
subj = agents_by_role.get('subj')
71-
obj = agents_by_role.get('obj')
72-
raw_texts = [raw_texts.get('subj'), raw_texts.get('obj')]
73-
coords = [coords.get('subj'), coords.get('obj')]
74-
annotations = {
75-
'agents': {
76-
'raw_text': raw_texts,
77-
'coords': coords
78-
}
79-
}
80-
evidence.annotations = annotations
81-
stmt = stmt_class(subj, obj, evidence=[evidence])
82-
return stmt
83-
elif issubclass(stmt_class, Modification):
84-
if agents_by_role.keys() != {'enz', 'sub'} or len(agents_by_role) != 2:
85-
raise ValueError("Expected exactly two roles: 'enz' and 'sub'",
86-
f" but got {agents_by_role.keys()}")
87-
88-
enz = agents_by_role.get('enz')
89-
sub = agents_by_role.get('sub')
90-
raw_texts = [raw_texts.get('enz'), raw_texts.get('sub')]
91-
coords = [coords.get('enz'), coords.get('sub')]
92-
annotations = {
93-
'agents': {
94-
'raw_text': raw_texts,
95-
'coords': coords
96-
}
97-
}
98-
evidence.annotations = annotations
99-
stmt = stmt_class(enz, sub, evidence=[evidence])
20+
"""Extract a statement from JSON using INDRA's built-in functionality."""
21+
try:
22+
# Use INDRA's built-in statement_from_json functionality
23+
stmt = stmt_from_json(entry)
24+
25+
# Apply grounding to agents if grounder is available
26+
if self.grounder:
27+
text = entry['evidence'][0]['text'] if entry.get('evidence') else ""
28+
self._apply_grounding(stmt, text)
29+
10030
return stmt
101-
else:
102-
assert False, "Unsupported statement type: %s" % stmt_class
31+
32+
except Exception as e:
33+
logger.warning(f"Error creating statement from JSON: {e}")
34+
raise
35+
36+
def _apply_grounding(self, stmt, context_text):
37+
"""Apply grounding to all agents in a statement."""
38+
# Get all agents from the statement
39+
agents = stmt.agent_list()
40+
41+
for agent in agents:
42+
if agent and agent.name:
43+
# Apply grounding
44+
grounding_result = self.grounder(agent.name, context_text)
45+
if grounding_result:
46+
# Update db_refs with grounding results
47+
agent.db_refs.update(grounding_result)
48+
49+
# Standardize the agent name
50+
standardize_agent_name(agent, standardize_refs=True)
10351

10452
def extract_statements(self):
10553
self.statements = []

0 commit comments

Comments (0)