1- from typing import List , Dict , Set , Tuple , Protocol
1+ from typing import List , Dict , Tuple
22from ..protocols import IdentifierProtocol
3- from ..helpers import RelationshipStrategy , ModelType
43from .model_suggester import ModelSuggester
5- from ..prompts import prompts as ps
64import guidance
75from guidance import system , user , assistant , gen
86import re
97
108
119class IdentificationSuggester (IdentifierProtocol ):
12-
1310 CONTEXT : str = """causal mechanisms"""
1411
15- def __init__ (self , llm ):
16- if llm == 'gpt-4' :
17- self .llm = guidance .models .OpenAI ('gpt-4' )
18- self .model_suggester = ModelSuggester ('gpt-4' )
12+ def __init__ (self , llm = None ):
13+ if llm is not None :
14+ if (llm == 'gpt-4' ):
15+ self .llm = guidance .models .OpenAI ('gpt-4' )
16+ self .model_suggester = ModelSuggester ('gpt-4' )
1917
2018 # def suggest_estimand(
2119 # self,
@@ -120,7 +118,7 @@ def suggest_backdoor(
120118 outcome : str ,
121119 factors_list : list (),
122120 expertise_list : list (),
123- analysis_context : list () = CONTEXT ,
121+ analysis_context = CONTEXT ,
124122 stakeholders : list () = None
125123 ):
126124 backdoor_set = self .model_suggester .suggest_confounders (
@@ -133,14 +131,14 @@ def suggest_backdoor(
133131 )
134132 return backdoor_set
135133
136- #TODO:implement
134+ # TODO:implement
137135 def suggest_frontdoor (
138136 self ,
139137 treatment : str ,
140138 outcome : str ,
141139 factors_list : list (),
142140 expertise_list : list (),
143- analysis_context : list () = CONTEXT ,
141+ analysis_context = CONTEXT ,
144142 stakeholders : list () = None
145143 ):
146144 pass
@@ -151,7 +149,7 @@ def suggest_mediators(
151149 outcome : str ,
152150 factors_list : list (),
153151 expertise_list : list (),
154- analysis_context : list () = CONTEXT ,
152+ analysis_context = CONTEXT ,
155153 stakeholders : list () = None
156154 ):
157155 expert_list : List [str ] = list ()
@@ -170,43 +168,28 @@ def suggest_mediators(
170168 if factors_list [i ] != treatment and factors_list [i ] != outcome :
171169 edited_factors_list .append (factors_list [i ])
172170
173- if len (expert_list ) > 1 :
174- for expert in expert_list :
175- mediators_edges , mediators_list = self .request_mediators (
176- treatment = treatment ,
177- outcome = outcome ,
178- analysis_context = analysis_context ,
179- domain_expertise = expert ,
180- factors_list = edited_factors_list ,
181- mediators_edges = mediators_edges
182- )
183- for m in mediators_list :
184- if m not in mediators :
185- mediators .append (m )
186- else :
171+ for expert in expert_list :
187172 mediators_edges , mediators_list = self .request_mediators (
188173 treatment = treatment ,
189174 outcome = outcome ,
190- analysis_context = analysis_context ,
191- domain_expertise = expert_list [0 ],
175+ domain_expertise = expert ,
192176 factors_list = edited_factors_list ,
193177 mediators_edges = mediators_edges ,
178+ analysis_context = analysis_context
194179 )
195-
196180 for m in mediators_list :
197181 if m not in mediators :
198182 mediators .append (m )
199-
200183 return mediators_edges , mediators
201184
202185 def request_mediators (
203186 self ,
204187 treatment ,
205188 outcome ,
206- analysis_context ,
207189 domain_expertise ,
208190 factors_list ,
209- mediators_edges
191+ mediators_edges ,
192+ analysis_context = CONTEXT
210193 ):
211194 mediators : List [str ] = list ()
212195
@@ -254,9 +237,7 @@ def request_mediators(
254237 # to not add it twice into the list
255238 if factor in factors_list and factor not in mediators :
256239 mediators .append (factor )
257- success = True
258- else :
259- success = False
240+ success = True
260241
261242 except KeyError :
262243 success = False
@@ -281,7 +262,7 @@ def suggest_ivs(
281262 outcome : str ,
282263 factors_list : list (),
283264 expertise_list : list (),
284- analysis_context : list () = CONTEXT ,
265+ analysis_context = CONTEXT ,
285266 stakeholders : list () = None
286267 ):
287268 expert_list : List [str ] = list ()
@@ -300,26 +281,20 @@ def suggest_ivs(
300281 if factors_list [i ] != treatment and factors_list [i ] != outcome :
301282 edited_factors_list .append (factors_list [i ])
302283
303- if len (expert_list ) > 1 :
304- for expert in expert_list :
305- self .request_ivs (
306- treatment = treatment ,
307- outcome = outcome ,
308- analysis_context = analysis_context ,
309- domain_expertise = expert ,
310- factors_list = edited_factors_list ,
311- iv_edges = iv_edges ,
312- )
313- else :
314- self .request_ivs (
284+ for expert in expert_list :
285+ iv_edges , iv_list = self .request_ivs (
315286 treatment = treatment ,
316287 outcome = outcome ,
317288 analysis_context = analysis_context ,
318- domain_expertise = expert_list [ 0 ] ,
289+ domain_expertise = expert ,
319290 factors_list = edited_factors_list ,
320291 iv_edges = iv_edges ,
321292 )
322293
294+ for m in iv_list :
295+ if m not in ivs :
296+ ivs .append (m )
297+
323298 return iv_edges , ivs
324299
325300 def request_ivs (
@@ -376,9 +351,7 @@ def request_ivs(
376351 for factor in iv_factors :
377352 if factor in factors_list and factor not in ivs :
378353 ivs .append (factor )
379- success = True
380- else :
381- success = False
354+ success = True
382355
383356 except KeyError :
384357 success = False
@@ -390,4 +363,4 @@ def request_ivs(
390363 else :
391364 iv_edges [(element , treatment )] = 1
392365
393- return iv_edges
366+ return iv_edges , ivs
0 commit comments