Skip to content

Commit 4377732

Browse files
authored
Merge pull request #98 from GitHubSecurityLab/model_settings
add support for model parameters passing
2 parents 646538c + d0dff7f commit 4377732

File tree

3 files changed

+60
-8
lines changed

3 files changed

+60
-8
lines changed

README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,26 @@ taskflow:
457457

458458
The model version can then be updated by changing `gpt_latest` in the `model_config` file and applied across all taskflows that use the config.
459459

460+
In addition, model-specific parameters can be provided via `model_config`. To do so, define a `model_settings` section in the `model_config` file. This section has to be a dictionary with the model names as keys:
461+
462+
```yaml
463+
model_settings:
464+
gpt_latest:
465+
temperature: 1
466+
reasoning:
467+
effort: high
468+
```
469+
470+
You do not need to set parameters for all models defined in the `models` section. When parameters are not set for a model, they fall back to their default values. However, every entry in this section must correspond to one of the models specified in the `models` section, otherwise an error will be raised:
471+
472+
```yaml
473+
model_settings:
474+
new_model:
475+
...
476+
```
477+
478+
The above will result in an error because `new_model` is not defined in the `models` section. Model parameters can also be set per task, and any settings defined in a task will override the settings in the config.
479+
460480
## Passing environment variables
461481

462482
Files of types `taskflow` and `toolbox` allow environment variables to be passed using the `env` field:

doc/GRAMMAR.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,18 @@ Tasks can optionally specify which Model to use on the configured inference endp
9191

9292
Note that model identifiers may differ between OpenAI compatible endpoint providers; make sure you change your model identifier accordingly when switching providers. If not specified, a default LLM model (`gpt-4o`) is used.
9393

94+
Parameters to the model can also be specified in the task using the `model_settings` section:
95+
96+
```yaml
97+
model: gpt-5-mini
98+
model_settings:
99+
temperature: 1
100+
reasoning:
101+
effort: high
102+
```
103+
104+
If `model_settings` is absent, the model parameters fall back either to their defaults or to the values supplied in a `model_config`. However, any parameters supplied in the task will override those set in the `model_config`.
105+
94106
### Completion Requirement
95107

96108
Tasks can be marked as requiring completion, if a required task fails, the taskflow will abort. This defaults to false.

src/seclab_taskflow_agent/__main__.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ async def deploy_task_agents(available_tools: AvailableTools,
107107
exclude_from_context: bool = False,
108108
max_turns: int = DEFAULT_MAX_TURNS,
109109
model: str = DEFAULT_MODEL,
110-
model_settings: ModelSettings | None = None,
110+
model_par: dict = {},
111111
run_hooks: TaskRunHooks | None = None,
112112
agent_hooks: TaskAgentHooks | None = None):
113113

@@ -130,10 +130,11 @@ async def deploy_task_agents(available_tools: AvailableTools,
130130

131131
# https://openai.github.io/openai-agents-python/ref/model_settings/
132132
parallel_tool_calls = True if os.getenv('MODEL_PARALLEL_TOOL_CALLS') else False
133-
model_settings = ModelSettings(
134-
temperature=os.getenv('MODEL_TEMP', default=0.0),
135-
tool_choice=('auto' if toolboxes else None),
136-
parallel_tool_calls=(parallel_tool_calls if toolboxes else None))
133+
model_params = {'temperature' : os.getenv('MODEL_TEMP', default = 0.0),
134+
'tool_choice' : ('auto' if toolboxes else None),
135+
'parallel_tool_calls' : (parallel_tool_calls if toolboxes else None)}
136+
model_params.update(model_par)
137+
model_settings = ModelSettings(**model_params)
137138

138139
# block tools if requested
139140
tool_filter = create_static_tool_filter(blocked_tool_names=blocked_tools) if blocked_tools else None
@@ -438,13 +439,22 @@ async def on_handoff_hook(
438439
global_variables.update(cli_globals)
439440
model_config = taskflow.get('model_config', {})
440441
model_keys = []
442+
models_params = {}
441443
if model_config:
442-
model_dict = available_tools.get_model_config(model_config)
443-
model_dict = model_dict.get('models', {})
444+
m_config = available_tools.get_model_config(model_config)
445+
model_dict = m_config.get('models', {})
444446
if model_dict:
445447
if not isinstance(model_dict, dict):
446448
raise ValueError(f"Models section of the model_config file {model_config} must be a dictionary")
447-
model_keys = model_dict.keys()
449+
model_keys = model_dict.keys()
450+
models_params = m_config.get('model_settings', {})
451+
if models_params and not isinstance(models_params, dict):
452+
raise ValueError(f"Settings section of model_config file {model_config} must be a dictionary")
453+
if not set(models_params.keys()).difference(model_keys).issubset(set([])):
454+
raise ValueError(f"Settings section of model_config file {model_config} contains models that are not in the model section")
455+
for k,v in models_params.items():
456+
if not isinstance(v, dict):
457+
raise ValueError(f"Settings for model {k} in model_config file {model_config} is not a dictionary")
448458

449459
for task in taskflow['taskflow']:
450460

@@ -465,8 +475,17 @@ async def on_handoff_hook(
465475
if k not in task_body:
466476
task_body[k] = v
467477
model = task_body.get('model', DEFAULT_MODEL)
478+
model_settings = {}
468479
if model in model_keys:
480+
if model in models_params:
481+
model_settings = models_params[model].copy()
469482
model = model_dict[model]
483+
task_model_settings = task_body.get('model_settings', {})
484+
if not isinstance(task_model_settings, dict):
485+
name = task.get('name', '')
486+
raise ValueError(f"model_settings in task {name} needs to be a dictionary")
487+
model_settings.update(task_model_settings)
488+
470489
# parse our taskflow grammar
471490
name = task_body.get('name', 'taskflow') # placeholder, not used yet
472491
description = task_body.get('description', 'taskflow') # placeholder not used yet
@@ -622,6 +641,7 @@ async def _deploy_task_agents(resolved_agents, prompt):
622641
on_tool_end=on_tool_end_hook,
623642
on_tool_start=on_tool_start_hook),
624643
model = model,
644+
model_par = model_settings,
625645
agent_hooks=TaskAgentHooks(
626646
on_handoff=on_handoff_hook))
627647
return result

0 commit comments

Comments
 (0)