Skip to content

Commit cdaf1a0

Browse files
committed
Fixed and tested all methods in examples notebook and moved Tuebingen model suggester class.
1 parent fb5bcfe commit cdaf1a0

File tree

4 files changed

+641
-606
lines changed

4 files changed

+641
-606
lines changed

pywhyllm/suggesters/identification_suggester.py

Lines changed: 143 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,25 @@
44
from .model_suggester import ModelSuggester
55
from ..prompts import prompts as ps
66
import guidance
7+
from guidance import system, user, assistant, gen
78
import re
89

10+
911
# from dowhy import causal_identifier as ci
1012

1113

1214
class IdentificationSuggester(IdentifierProtocol):
13-
EXPERTS: list() = [
14-
"cause and effect",
15-
"causality, you are an intelligent AI with expertise in causality",
16-
]
15+
# EXPERTS: list() = [
16+
# "cause and effect",
17+
# "causality, you are an intelligent AI with expertise in causality",
18+
# ]
1719
CONTEXT: str = """causal mechanisms"""
1820

21+
def __init__(self, llm):
22+
if llm == 'gpt-4':
23+
self.llm = guidance.models.OpenAI('gpt-4')
24+
self.model_suggester = ModelSuggester('gpt-4')
25+
1926
# def suggest_estimand(
2027
# self,
2128
# treatment: str,
@@ -114,41 +121,44 @@ class IdentificationSuggester(IdentifierProtocol):
114121
# return estimand
115122

116123
def suggest_backdoor(
117-
self,
118-
treatment: str,
119-
outcome: str,
120-
factors_list: list(),
121-
llm: guidance.models,
122-
experts: list() = EXPERTS,
123-
analysis_context: list() = CONTEXT,
124-
stakeholders: list() = None,
125-
temperature=0.3,
126-
model_type: ModelType = ModelType.Completion,
124+
self,
125+
treatment: str,
126+
outcome: str,
127+
factors_list: list(),
128+
expertise_list: list(),
129+
analysis_context: list() = CONTEXT,
130+
stakeholders: list() = None
127131
):
128-
backdoor_set = ModelSuggester.suggest_confounders(
129-
analysis_context=analysis_context,
132+
backdoor_set = self.model_suggester.suggest_confounders(
130133
treatment=treatment,
131134
outcome=outcome,
132135
factors_list=factors_list,
133-
experts=experts,
134-
llm=llm,
135-
stakeholders=stakeholders,
136-
temperature=temperature,
137-
model_type=model_type,
136+
expertise_list=expertise_list,
137+
analysis_context=analysis_context,
138+
stakeholders=stakeholders
138139
)
139140
return backdoor_set
140141

142+
#TODO:implement
143+
def suggest_frontdoor(
144+
self,
145+
treatment: str,
146+
outcome: str,
147+
factors_list: list(),
148+
expertise_list: list(),
149+
analysis_context: list() = CONTEXT,
150+
stakeholders: list() = None
151+
):
152+
pass
153+
141154
def suggest_mediators(
142-
self,
143-
treatment: str,
144-
outcome: str,
145-
factors_list: list(),
146-
llm: guidance.models,
147-
experts: list() = EXPERTS,
148-
analysis_context: list() = CONTEXT,
149-
stakeholders: list() = None,
150-
temperature=0.3,
151-
model_type: ModelType = ModelType.Completion,
155+
self,
156+
treatment: str,
157+
outcome: str,
158+
factors_list: list(),
159+
experts: list(),
160+
analysis_context: list() = CONTEXT,
161+
stakeholders: list() = None
152162
):
153163
expert_list: List[str] = list()
154164
for elements in experts:
@@ -157,8 +167,6 @@ def suggest_mediators(
157167
for elements in stakeholders:
158168
expert_list.append(elements)
159169

160-
suggest = guidance(ps[model_type.value]["expert_suggests_mediators"])
161-
162170
mediators: List[str] = list()
163171
mediators_edges: Dict[Tuple[str, str], int] = dict()
164172
mediators_edges[(treatment, outcome)] = 1
@@ -171,29 +179,23 @@ def suggest_mediators(
171179
if len(expert_list) > 1:
172180
for expert in expert_list:
173181
mediators_edges, mediators_list = self.request_mediators(
174-
suggest=suggest,
175182
treatment=treatment,
176183
outcome=outcome,
177184
analysis_context=analysis_context,
178-
expert=expert,
179-
edited_factors_list=edited_factors_list,
180-
temperature=temperature,
181-
llm=llm,
182-
mediators_edges=mediators_edges,
185+
domain_expertise=expert,
186+
factors_list=edited_factors_list,
187+
mediators_edges=mediators_edges
183188
)
184189
for m in mediators_list:
185190
if m not in mediators:
186191
mediators.append(m)
187192
else:
188193
mediators_edges, mediators_list = self.request_mediators(
189-
suggest=suggest,
190194
treatment=treatment,
191195
outcome=outcome,
192196
analysis_context=analysis_context,
193-
expert=expert_list[0],
194-
edited_factors_list=edited_factors_list,
195-
temperature=temperature,
196-
llm=llm,
197+
domain_expertise=expert_list[0],
198+
factors_list=edited_factors_list,
197199
mediators_edges=mediators_edges,
198200
)
199201

@@ -204,46 +206,62 @@ def suggest_mediators(
204206
return mediators_edges, mediators
205207

206208
def request_mediators(
207-
self,
208-
suggest,
209-
treatment,
210-
outcome,
211-
analysis_context,
212-
expert,
213-
edited_factors_list,
214-
temperature,
215-
llm,
216-
mediators_edges,
209+
self,
210+
treatment,
211+
outcome,
212+
analysis_context,
213+
domain_expertise,
214+
factors_list,
215+
mediators_edges
217216
):
218217
mediators: List[str] = list()
219218

220219
success: bool = False
221220

222221
while not success:
223222
try:
224-
output = suggest(
225-
treatment=treatment,
226-
outcome=outcome,
227-
analysis_context=analysis_context,
228-
domain_expertise=expert,
229-
factors_list=edited_factors_list,
230-
factor=factor,
231-
temperature=temperature,
232-
llm=llm,
233-
)
223+
lm = self.llm
224+
with system():
225+
lm += f"""You are an expert in {domain_expertise} and are studying {analysis_context}. You are using your
226+
knowledge to help build a causal model that contains all the assumptions about the factors that are directly
227+
influencing athe {outcome}. Where a causal model is a conceptual model that describes the causal mechanisms
228+
of a system. You will do this by by answering questions about cause and effect and using your domain knowledge
229+
in {domain_expertise}. Follow the next two steps, and complete the first one before moving on to the second:"""
230+
231+
with user():
232+
lm += f"""(1) From your perspective as an expert in {domain_expertise}, think step by step as you consider the factors
233+
that may interact between the {treatment} and the {outcome}. Use your knowledge as an expert in
234+
{domain_expertise} to describe the mediators, if there are any at all, between {treatment} and the
235+
{outcome}. Be concise and keep your thinking within two paragraphs. Then provide your step by step chain
236+
of thoughts within the tags <thinking></thinking>. (2) From your perspective as an expert in {domain_expertise},
237+
which factor(s) of the following factors, if any at all, has/have a high likelihood of directly influencing and
238+
causing the assignment of the {outcome} and also has/have a high likelihood of being directly influenced and
239+
caused by the assignment of the {treatment}? Which factor(s) of the following factors, if any at all, is/are
240+
on the causal chain that links the {treatment} to the {outcome}? From your perspective as an expert in
241+
{domain_expertise}, which factor(s) of the following factors, if any at all, mediates, is/are on the causal
242+
chain, that links the {treatment} to the {outcome}? Then provide your step by step chain of thoughts within
243+
the tags <thinking></thinking>. factor_names : {factors_list} Wrap the name of the factor(s), if any at all,
244+
that has/have a high likelihood of directly influencing and causing the assignment of the {outcome} and also
245+
has/have a high likelihood of being directly influenced and caused by the assignment of the {treatment} within
246+
the tags <mediating_factor>factor_name</mediating_factor>. Where factor_name is one of the items within the
247+
factor_names list. If a factor does not have a high likelihood of mediating, then do not wrap the factor with
248+
any tags. Your step by step answer as an in {domain_expertise}:"""
249+
250+
with assistant():
251+
lm += gen("output")
252+
253+
output = lm["output"]
254+
234255
mediating_factor = re.findall(
235-
r"<mediating_factor>(.*?)</mediating_factor>",
236-
output["output"],
237-
)
256+
r"<mediating_factor>(.*?)</mediating_factor>", output)
238257

239258
if mediating_factor:
240259
for factor in mediating_factor:
241260
# to not add it twice into the list
242-
if factor in edited_factors_list and factor not in mediators:
261+
if factor in factors_list and factor not in mediators:
243262
mediators.append(factor)
244263
success = True
245264
else:
246-
llm.OpenAI.cache.clear()
247265
success = False
248266

249267
except KeyError:
@@ -252,8 +270,8 @@ def request_mediators(
252270

253271
for element in mediators:
254272
if (treatment, element) in mediators_edges and (
255-
element,
256-
outcome,
273+
element,
274+
outcome,
257275
) in mediators_edges:
258276
mediators_edges[(treatment, element)] += 1
259277
mediators_edges[(element, outcome)] += 1
@@ -264,26 +282,21 @@ def request_mediators(
264282
return mediators_edges, mediators
265283

266284
def suggest_ivs(
267-
self,
268-
treatment: str,
269-
outcome: str,
270-
factors_list: list(),
271-
llm: guidance.models,
272-
experts: list() = EXPERTS,
273-
analysis_context: list() = CONTEXT,
274-
stakeholders: list() = None,
275-
temperature=0.3,
276-
model_type: ModelType = ModelType.Completion,
285+
self,
286+
treatment: str,
287+
outcome: str,
288+
factors_list: list(),
289+
expertise_list: list(),
290+
analysis_context: list() = CONTEXT,
291+
stakeholders: list() = None
277292
):
278293
expert_list: List[str] = list()
279-
for elements in experts:
294+
for elements in expertise_list:
280295
expert_list.append(elements)
281296
if stakeholders is not None:
282297
for elements in stakeholders:
283298
expert_list.append(elements)
284299

285-
suggest = guidance(ps[model_type.value]["expert_suggests_mediators"])
286-
287300
ivs: List[str] = list()
288301
iv_edges: Dict[Tuple[str, str], int] = dict()
289302
iv_edges[(treatment, outcome)] = 1
@@ -296,72 +309,81 @@ def suggest_ivs(
296309
if len(expert_list) > 1:
297310
for expert in expert_list:
298311
self.request_ivs(
299-
suggest=suggest,
300312
treatment=treatment,
301313
outcome=outcome,
302314
analysis_context=analysis_context,
303-
expert=expert,
304-
edited_factors_list=edited_factors_list,
305-
temperature=temperature,
306-
llm=llm,
315+
domain_expertise=expert,
316+
factors_list=edited_factors_list,
307317
iv_edges=iv_edges,
308318
)
309319
else:
310320
self.request_ivs(
311-
suggest=suggest,
312321
treatment=treatment,
313322
outcome=outcome,
314323
analysis_context=analysis_context,
315-
expert=expert_list[0],
316-
edited_factors_list=edited_factors_list,
317-
temperature=temperature,
318-
llm=llm,
324+
domain_expertise=expert_list[0],
325+
factors_list=edited_factors_list,
319326
iv_edges=iv_edges,
320327
)
321328

322329
return iv_edges, ivs
323330

324331
def request_ivs(
325-
self,
326-
suggest,
327-
treatment,
328-
outcome,
329-
analysis_context,
330-
expert,
331-
edited_factors_list,
332-
temperature,
333-
llm,
334-
iv_edges,
332+
self,
333+
treatment,
334+
outcome,
335+
analysis_context,
336+
domain_expertise,
337+
factors_list,
338+
iv_edges
335339
):
336340
ivs: List[str] = list()
337341

338342
success: bool = False
339343

340344
while not success:
341345
try:
342-
output = suggest(
343-
treatment=treatment,
344-
outcome=outcome,
345-
analysis_context=analysis_context,
346-
domain_expertise=expert,
347-
factors_list=edited_factors_list,
348-
factor=factor,
349-
temperature=temperature,
350-
llm=llm,
351-
)
352-
iv_factors = re.findall(
353-
r"<iv_factor>(.*?)</iv_factor>",
354-
output["output"],
355-
)
346+
lm = self.llm
347+
with system():
348+
lm += f"""You are an expert in {domain_expertise} and are studying {analysis_context}.
349+
You are using your knowledge to help build a causal model that contains all the assumptions about the factors
350+
that are directly influencing the {outcome}. Where a causal model is a conceptual model that describes the
351+
causal mechanisms of a system. You will do this by by answering questions about cause and effect and using
352+
your domain knowledge in {domain_expertise}. Follow the next two steps, and complete the first one before
353+
moving on to the second:"""
354+
355+
with user():
356+
lm += f"""(1) From your perspective as an expert in {domain_expertise}, think step by step
357+
as you consider the factors that may interact with the {treatment} and do not interact with {outcome}.
358+
Use your knowlegde as an expert in {domain_expertise} to describe the instrumental variable(s),
359+
if there are any at all, that both has/have a high likelihood of influecing and causing the {treatment} and
360+
has/have a very low likelihood of influencing and causing the {outcome}. Be concise and keep your thinking
361+
within two paragraphs. Then provide your step by step chain of thoughts within the tags
362+
<thinking></thinking>. (2) From your perspective as an expert in {domain_expertise}, which factor(s) of the
363+
following factors, if there are any at all, both has/have a high likelihood of influecing and causing the {
364+
treatment} and has/have a very low likelihood of influencing and causing the {outcome}? Which factor(s) of
365+
the following factors, if any at all, has/have a causal link to the {treatment} and has not causal link to
366+
the {outcome}? Which factor(s) of the following factors, if any at all, are (an) instrumental variable(s)
367+
to the causal relationship of the {treatment} causing the {outcome}? Be concise and keep your thinking
368+
within two paragraphs. Then provide your step by step chain of thoughts within the tags
369+
<thinking></thinking>. factor_names : {factors_list} Wrap the name of the factor(s), if there are any at
370+
all, that both has/have a high likelihood of influecing and causing the {treatment} and has/have a very low
371+
likelihood of influencing and causing the {outcome}, within the tags <iv_factor>factor_name</iv_factor>.
372+
Where factor_name is one of the items within the factor_names list. If a factor does not have a high
373+
likelihood of being an instrumental variable, then do not wrap the factor with any tags. Your step by step
374+
answer as an in {domain_expertise}:"""
375+
with assistant():
376+
lm += gen("output")
377+
378+
output = lm["output"]
379+
iv_factors = re.findall(r"<iv_factor>(.*?)</iv_factor>", output)
356380

357381
if iv_factors:
358382
for factor in iv_factors:
359-
# to not add it twice into the list
360-
if factor in edited_factors_list and factor not in ivs:
383+
if factor in factors_list and factor not in ivs:
361384
ivs.append(factor)
362385
success = True
363386
else:
364-
llm.OpenAI.cache.clear()
365387
success = False
366388

367389
except KeyError:

0 commit comments

Comments
 (0)