diff --git a/pywhyllm/suggesters/identification_suggester.py b/pywhyllm/suggesters/identification_suggester.py index 5cd0d15..f7d7948 100644 --- a/pywhyllm/suggesters/identification_suggester.py +++ b/pywhyllm/suggesters/identification_suggester.py @@ -11,9 +11,14 @@ class IdentificationSuggester(IdentifierProtocol): def __init__(self, llm=None): if llm is not None: - if (llm == 'gpt-4'): + if llm == 'gpt-4': self.llm = guidance.models.OpenAI('gpt-4') self.model_suggester = ModelSuggester('gpt-4') + elif isinstance(llm, guidance.models.Model): + self.llm = llm + self.model_suggester = ModelSuggester(llm) + else: + raise ValueError("llm must be either 'gpt-4' or a guidance model instance.") # def suggest_estimand( # self, diff --git a/pywhyllm/suggesters/model_suggester.py b/pywhyllm/suggesters/model_suggester.py index b7b9df2..11b7642 100644 --- a/pywhyllm/suggesters/model_suggester.py +++ b/pywhyllm/suggesters/model_suggester.py @@ -12,8 +12,12 @@ class ModelSuggester(ModelerProtocol): def __init__(self, llm=None): if llm is not None: - if (llm == 'gpt-4'): + if llm == 'gpt-4': self.llm = guidance.models.OpenAI('gpt-4') + elif isinstance(llm, guidance.models.Model): + self.llm = llm + else: + raise ValueError("llm must be either 'gpt-4' or a guidance model instance.") def suggest_domain_expertises( self, diff --git a/pywhyllm/suggesters/simple_identification_suggester.py b/pywhyllm/suggesters/simple_identification_suggester.py index cc404a9..36369eb 100644 --- a/pywhyllm/suggesters/simple_identification_suggester.py +++ b/pywhyllm/suggesters/simple_identification_suggester.py @@ -7,8 +7,12 @@ class SimpleIdentificationSuggester: def __init__(self, llm=None): if llm is not None: - if (llm == 'gpt-4'): + if llm == 'gpt-4': self.llm = guidance.models.OpenAI('gpt-4') + elif isinstance(llm, guidance.models.Model): + self.llm = llm + else: + raise ValueError("llm must be either 'gpt-4' or a guidance model instance.") def suggest_iv(self, factors, treatment, outcome): lm = self.llm diff --git a/pywhyllm/suggesters/simple_model_suggester.py b/pywhyllm/suggesters/simple_model_suggester.py index d2b130a..26d389f 100644 --- a/pywhyllm/suggesters/simple_model_suggester.py +++ b/pywhyllm/suggesters/simple_model_suggester.py @@ -22,8 +22,12 @@ class SimpleModelSuggester: def __init__(self, llm=None): if llm is not None: - if (llm == 'gpt-4'): + if llm == 'gpt-4': self.llm = guidance.models.OpenAI('gpt-4') + elif isinstance(llm, guidance.models.Model): + self.llm = llm + else: + raise ValueError("llm must be either 'gpt-4' or a guidance model instance.") # new ver def suggest_pairwise_relationship(self, variable1: str, variable2: str): diff --git a/pywhyllm/suggesters/validation_suggester.py b/pywhyllm/suggesters/validation_suggester.py index dcd0cc5..02904a6 100644 --- a/pywhyllm/suggesters/validation_suggester.py +++ b/pywhyllm/suggesters/validation_suggester.py @@ -18,6 +18,10 @@ def __init__(self, llm=None): if llm is not None: if llm == 'gpt-4': self.llm = guidance.models.OpenAI('gpt-4') + elif isinstance(llm, guidance.models.Model): + self.llm = llm + else: + raise ValueError("llm must be either 'gpt-4' or a guidance model instance.") def suggest_negative_controls( self,