diff --git a/toolkit/components/ml/content/backends/OpenAIPipeline.mjs b/toolkit/components/ml/content/backends/OpenAIPipeline.mjs index cce99d8282365..326c3a9e049a1 100644 --- a/toolkit/components/ml/content/backends/OpenAIPipeline.mjs +++ b/toolkit/components/ml/content/backends/OpenAIPipeline.mjs @@ -26,6 +26,8 @@ let _logLevel = "Error"; */ const lazy = {}; +const DEFAULT_ALLOWED_OPENAI_PARAMS = Object.freeze(["tools", "tool_choice"]); + ChromeUtils.defineLazyGetter(lazy, "console", () => { return console.createInstance({ maxLogLevel: _logLevel, // we can't use maxLogLevelPref in workers. @@ -68,6 +70,9 @@ export class OpenAIPipeline { let config = {}; options.applyToConfig(config); config.backend = config.backend || "openai"; + if (!config.allowedOpenAIParams) { + config.allowedOpenAIParams = DEFAULT_ALLOWED_OPENAI_PARAMS; + } // reapply logLevel if it has changed. if (lazy.console.logLevel != config.logLevel) { @@ -334,6 +339,11 @@ export class OpenAIPipeline { }); const stream = request.streamOptions?.enabled || false; const tools = request.tools || []; + const allowedOpenAIParams = + request.allowed_openai_params ?? + request.allowedOpenAIParams ?? + this.#options.allowedOpenAIParams ?? + DEFAULT_ALLOWED_OPENAI_PARAMS; const completionParams = { model: modelId, @@ -342,6 +352,13 @@ export class OpenAIPipeline { tools, }; + if ( + Array.isArray(allowedOpenAIParams) && + allowedOpenAIParams.length > 0 + ) { + completionParams.allowed_openai_params = [...allowedOpenAIParams]; + } + const args = { client, completionParams,