Skip to content

Commit 077c87b

Browse files
committed
Use max_completion_tokens for OpenAI models on OCI GenAI
OpenAI models served through OCI GenAI reject the max_tokens parameter and require max_completion_tokens instead. Detect provider=openai in the oci_langchain wrapper and use the correct key in model_kwargs.
1 parent 1ea626a commit 077c87b

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,10 @@ async def oci_langchain(llm_config: OCIModelConfig, _builder: Builder):
244244
if llm_config.top_p is not None:
245245
model_kwargs["top_p"] = llm_config.top_p
246246
if llm_config.max_tokens is not None:
247-
model_kwargs["max_tokens"] = llm_config.max_tokens
247+
if llm_config.provider and llm_config.provider.lower() == "openai":
248+
model_kwargs["max_completion_tokens"] = llm_config.max_tokens
249+
else:
250+
model_kwargs["max_tokens"] = llm_config.max_tokens
248251
if llm_config.seed is not None:
249252
model_kwargs["seed"] = llm_config.seed
250253

packages/nvidia_nat_langchain/tests/test_llm_langchain.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,30 @@ async def test_basic_creation(self, mock_get_chat, mock_create_client_kwargs, mo
214214
}
215215
assert client is mock_chat_class.return_value
216216

@patch("oci.generative_ai_inference.GenerativeAiInferenceClient")
@patch("langchain_oci.common.auth.create_oci_client_kwargs")
@patch("nat.plugins.langchain.llm._get_langchain_oci_chat_model")
async def test_openai_provider_uses_max_completion_tokens(
    self, mock_get_chat, mock_create_client_kwargs, mock_oci_client, mock_builder
):
    """OpenAI models on OCI GenAI must receive ``max_completion_tokens``.

    OpenAI-provider models served through OCI GenAI reject ``max_tokens``;
    verify that ``oci_langchain`` translates ``max_tokens`` from the config
    into ``max_completion_tokens`` in ``model_kwargs`` and does NOT also
    pass the rejected ``max_tokens`` key.
    """
    # Stub out the chat-model class so no real OCI client is constructed.
    mock_chat_class = MagicMock()
    mock_get_chat.return_value = mock_chat_class
    mock_create_client_kwargs.return_value = {"config": {"region": "us-chicago-1"}}

    cfg = OCIModelConfig(
        model_name="openai.gpt-5.4",
        compartment_id="ocid1.compartment.oc1..example",
        endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
        provider="openai",  # triggers the max_completion_tokens path
        max_tokens=128,
    )

    async with oci_langchain(cfg, mock_builder) as client:
        kwargs = mock_chat_class.call_args.kwargs
        # The rejected key must be absent and the accepted key must carry
        # the configured value.
        assert "max_completion_tokens" in kwargs["model_kwargs"]
        assert "max_tokens" not in kwargs["model_kwargs"]
        assert kwargs["model_kwargs"]["max_completion_tokens"] == 128
217241
@patch("nat.plugins.langchain.llm._get_langchain_oci_chat_model")
218242
async def test_api_type_validation(self, mock_get_chat, oci_cfg_wrong_api, mock_builder):
219243
with pytest.raises(ValueError):

0 commit comments

Comments (0)