Skip to content

Commit a71bf37

Browse files
committed
revert generator
1 parent 839d018 commit a71bf37

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

adalflow/adalflow/core/generator.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,22 @@ class Generator(GradComponent, CachedEngine, CallbackManager):
7070
template (Optional[str], optional): The template for the prompt. Defaults to :ref:`DEFAULT_ADALFLOW_SYSTEM_PROMPT<core-default_prompt_template>`.
7171
prompt_kwargs (Optional[Dict], optional): The preset prompt kwargs to fill in the variables in the prompt. Defaults to None.
7272
output_processors (Optional[Component], optional): The output processors after model call. It can be a single component or a chained component via ``Sequential``. Defaults to None.
73+
trainable_params (Optional[List[str]], optional): The list of trainable parameters. Defaults to [].
74+
Note:
75+
The output_processors will be applied to the string output of the model completion. And the result will be stored in the data field of the output.
76+
And we encourage you to only use it to parse the response to the data format you will use later.
7377
name (Optional[str], optional): The name of the generator. Defaults to None.
7478
cache_path (Optional[str], optional): The path to save the cache. Defaults to None.
7579
use_cache (bool, optional): Whether to use cache. Defaults to False.
7680
"""
7781

82+
model_type: ModelType = ModelType.LLM
83+
model_client: ModelClient # for better type checking
84+
_use_cache: bool = False
85+
_kwargs: Dict[str, Any] = (
86+
{}
87+
) # to create teacher generator from student TODO: might reassess this
88+
7889
def __init__(
7990
self,
8091
*,
@@ -90,6 +101,7 @@ def __init__(
90101
# args for the cache
91102
cache_path: Optional[str] = None,
92103
use_cache: bool = False,
104+
trainable_params: Optional[List[str]] = [],
93105
) -> None:
94106
r"""The default prompt is set to the DEFAULT_ADALFLOW_SYSTEM_PROMPT. It has the following variables:
95107
- task_desc_str
@@ -122,7 +134,6 @@ def __init__(
122134
CallbackManager.__init__(self)
123135

124136
self.name = name or self.__class__.__name__
125-
self.model_type = model_client.model_type # Get model type from client
126137

127138
self._init_prompt(template, prompt_kwargs)
128139

@@ -153,6 +164,7 @@ def __init__(
153164
"name": name,
154165
"cache_path": cache_path,
155166
"use_cache": use_cache,
167+
"trainable_params": trainable_params,
156168
}
157169
self._teacher: Optional["Generator"] = None
158170
self._trace_api_kwargs: Dict[str, Any] = (
@@ -326,6 +338,7 @@ def _pre_call(self, prompt_kwargs: Dict, model_kwargs: Dict) -> Dict[str, Any]:
326338
api_kwargs = self.model_client.convert_inputs_to_api_kwargs(
327339
input=prompt_str,
328340
model_kwargs=composed_model_kwargs,
341+
model_type=self.model_type,
329342
)
330343
return api_kwargs
331344

0 commit comments

Comments
 (0)