|
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| 15 | + |
| 16 | +from collections.abc import AsyncIterator |
| 17 | + |
| 18 | +from pydantic import AliasChoices |
| 19 | +from pydantic import ConfigDict |
| 20 | +from pydantic import Field |
| 21 | + |
| 22 | +from nat.builder.builder import Builder |
| 23 | +from nat.builder.llm import LLMProviderInfo |
| 24 | +from nat.cli.register_workflow import register_llm_provider |
| 25 | +from nat.data_models.llm import LLMBaseConfig |
| 26 | +from nat.data_models.optimizable import OptimizableField |
| 27 | +from nat.data_models.optimizable import OptimizableMixin |
| 28 | +from nat.data_models.optimizable import SearchSpace |
| 29 | +from nat.data_models.retry_mixin import RetryMixin |
| 30 | +from nat.data_models.ssl_verification_mixin import SSLVerificationMixin |
| 31 | +from nat.data_models.thinking_mixin import ThinkingMixin |
| 32 | + |
class OCIModelConfig(LLMBaseConfig, RetryMixin, OptimizableMixin, ThinkingMixin, SSLVerificationMixin, name="oci"):
    """OCI Generative AI LLM provider."""

    # protected_namespaces=() permits the "model_*" field names below;
    # extra="allow" passes any additional kwargs through to the client untouched.
    model_config = ConfigDict(protected_namespaces=(), extra="allow")

    # -- Endpoint and authentication --------------------------------------
    endpoint: str | None = Field(
        default=None,
        validation_alias=AliasChoices("endpoint", "service_endpoint", "base_url"),
        description="OCI Generative AI service endpoint URL.",
    )
    compartment_id: str | None = Field(
        default=None,
        description="OCI compartment OCID for Generative AI requests.",
    )
    auth_type: str = Field(
        default="API_KEY",
        description=("OCI SDK authentication type: API_KEY, SECURITY_TOKEN, INSTANCE_PRINCIPAL, "
                     "or RESOURCE_PRINCIPAL."),
    )
    auth_profile: str = Field(
        default="DEFAULT",
        description="OCI config profile to use for API_KEY or SECURITY_TOKEN auth.",
    )
    auth_file_location: str = Field(
        default="~/.oci/config",
        description="Path to the OCI config file used for SDK authentication.",
    )

    # -- Model selection ---------------------------------------------------
    # Accepts "model" as an input alias and serializes back out as "model".
    model_name: str = OptimizableField(
        validation_alias=AliasChoices("model_name", "model"),
        serialization_alias="model",
        description="The OCI Generative AI model ID.",
    )
    provider: str | None = Field(
        default=None,
        description="Optional OCI provider override such as cohere, google, meta, or openai.",
    )

    # -- Generation and request parameters ---------------------------------
    context_size: int | None = Field(
        default=1024,
        gt=0,
        description="The maximum number of tokens available for input.",
    )
    seed: int | None = Field(default=None, description="Random seed to set for generation.")
    max_retries: int = Field(default=10, description="The max number of retries for the request.")
    max_tokens: int | None = Field(default=None, gt=0, description="Maximum number of output tokens.")
    # temperature and top_p expose tuning ranges to the optimizer via SearchSpace.
    temperature: float | None = OptimizableField(
        default=None,
        ge=0.0,
        description="Sampling temperature to control randomness in the output.",
        space=SearchSpace(high=0.9, low=0.1, step=0.2),
    )
    top_p: float | None = OptimizableField(
        default=None,
        ge=0.0,
        le=1.0,
        description="Top-p for distribution sampling.",
        space=SearchSpace(high=1.0, low=0.5, step=0.1),
    )
    request_timeout: float | None = Field(default=None, gt=0.0, description="HTTP request timeout in seconds.")
| 75 | + |
| 76 | + |
@register_llm_provider(config_type=OCIModelConfig)
async def oci_llm(config: OCIModelConfig, _builder: Builder) -> AsyncIterator[LLMProviderInfo]:
    """Yield provider metadata for an OCI Generative AI model.

    Args:
        config: OCI model configuration.
        _builder: Builder instance (unused here).

    Yields:
        LLMProviderInfo describing the configured OCI model.
    """
    info = LLMProviderInfo(config=config, description="An OCI Generative AI model for use with an LLM client.")
    yield info
0 commit comments