-
Notifications
You must be signed in to change notification settings - Fork 31
Fix PredictModel to return correct output fields
#55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| import dspy | ||
| from dspy.adapters.chat_adapter import ChatAdapter, prepare_instructions | ||
| from cognify.llm import Model, StructuredModel, Input, OutputFormat | ||
| from cognify.llm import Model, StructuredModel, Input, OutputFormat, OutputLabel | ||
| from cognify.llm.model import LMConfig | ||
| from pydantic import BaseModel, create_model | ||
| from typing import Any, Dict, Type | ||
|
|
@@ -41,48 +41,62 @@ def cognify_predictor( | |
|
|
||
| if not isinstance(dspy_predictor, dspy.Predict): | ||
| warnings.warn( | ||
| "Original module is not a `Predict`. This may result in lossy translation", | ||
| "Original module is NOT a `dspy.Predict`. This may result in lossy translation", | ||
| UserWarning, | ||
| ) | ||
|
|
||
| if isinstance(dspy_predictor, dspy.Retrieve): | ||
| warnings.warn( | ||
| "Original module is a `Retrieve`. This will be ignored", UserWarning | ||
| "Original module is a `dspy.Retrieve`. This will be ignored", UserWarning | ||
| ) | ||
| self.ignore_module = True | ||
| return None | ||
|
|
||
| # initialize cog lm | ||
| system_prompt = prepare_instructions(dspy_predictor.signature) | ||
| input_names = list(dspy_predictor.signature.input_fields.keys()) | ||
| input_variables = [Input(name=input_name) for input_name in input_names] | ||
|
|
||
| output_fields = dspy_predictor.signature.output_fields | ||
| if "reasoning" in output_fields: | ||
| del output_fields["reasoning"] | ||
| # stripping the reasoning field may crash their workflow, so we warn users instead | ||
| warnings.warn( | ||
| "Original module contained reasoning. This will be stripped. Add reasoning as a cog instead", | ||
| f"DSPy module {name} contained reasoning. This may lead to undefined behavior.", | ||
|
||
| UserWarning, | ||
| ) | ||
| output_fields_for_schema = {k: v.annotation for k, v in output_fields.items()} | ||
| self.output_schema = generate_pydantic_model( | ||
| "OutputData", output_fields_for_schema | ||
| ) | ||
| system_prompt = prepare_instructions(dspy_predictor.signature) | ||
|
|
||
| # lm config | ||
| lm_client: dspy.LM = dspy.settings.get("lm", None) | ||
|
|
||
| assert lm_client, "Expected lm to be configured in dspy" | ||
| lm_config = LMConfig(model=lm_client.model, kwargs=lm_client.kwargs) | ||
|
|
||
| # always treat as structured to provide compatiblity with forward function | ||
| return StructuredModel( | ||
| # treat as cognify.Model, allow dspy to handle output parsing | ||
| return Model( | ||
| agent_name=name, | ||
| system_prompt=system_prompt, | ||
| input_variables=input_variables, | ||
| output_format=OutputFormat(schema=self.output_schema), | ||
| lm_config=lm_config, | ||
| output=OutputLabel("llm_output"), | ||
| lm_config=lm_config | ||
| ) | ||
|
|
||
| def construct_messages(self, inputs): | ||
| messages = None | ||
| if self.predictor: | ||
| messages: APICompatibleMessage = self.chat_adapter.format( | ||
| self.predictor.signature, self.predictor.demos, inputs | ||
| ) | ||
| return messages | ||
|
|
||
| def parse_output(self, result): | ||
| values = [] | ||
|
|
||
| # from dspy chat adapter __call__ | ||
| value = self.chat_adapter.parse(self.predictor.signature, result, _parse_values=True) | ||
| assert set(value.keys()) == set(self.predictor.signature.output_fields.keys()), f"Expected {self.predictor.signature.output_fields.keys()} but got {value.keys()}" | ||
| values.append(value) | ||
|
|
||
| return values | ||
|
|
||
| def forward(self, **kwargs): | ||
| assert ( | ||
|
|
@@ -95,19 +109,12 @@ def forward(self, **kwargs): | |
| inputs: Dict[str, str] = { | ||
| k.name: kwargs[k.name] for k in self.cog_lm.input_variables | ||
| } | ||
| messages = None | ||
| if self.predictor: | ||
| messages: APICompatibleMessage = self.chat_adapter.format( | ||
| self.predictor.signature, self.predictor.demos, inputs | ||
| ) | ||
| messages = self.construct_messages(inputs) | ||
| result = self.cog_lm( | ||
| messages, inputs | ||
| ) # kwargs have already been set when initializing cog_lm | ||
| kwargs: dict = result.model_dump() | ||
| for k,v in kwargs.items(): | ||
| if not v: | ||
| raise ValueError(f"{self.cog_lm.name} did not generate a value for field `{k}`, consider using a larger model for structured output") | ||
| return dspy.Prediction(**kwargs) | ||
| completions = self.parse_output(result) | ||
| return dspy.Prediction.from_completions(completions, signature=self.predictor.signature) | ||
|
|
||
|
|
||
| def as_predict(cog_lm: Model) -> PredictModel: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does this `Retrieve` check mean? Why is it ignored?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`dspy.Retrieve` is their retriever. They still consider retrieval to be a "module", and since we don't optimize retrieval, I just ignore it. This means whenever it gets called in the actual workflow, we will call the original retrieve module.