diff --git a/docs/docs/providers/agents/index.mdx b/docs/docs/providers/agents/index.mdx index 06eb104afa..52b92734e6 100644 --- a/docs/docs/providers/agents/index.mdx +++ b/docs/docs/providers/agents/index.mdx @@ -1,7 +1,7 @@ --- description: "Agents - APIs for creating and interacting with agentic systems." +APIs for creating and interacting with agentic systems." sidebar_label: Agents title: Agents --- @@ -12,6 +12,6 @@ title: Agents Agents - APIs for creating and interacting with agentic systems. +APIs for creating and interacting with agentic systems. This section contains documentation for all available providers for the **agents** API. diff --git a/docs/docs/providers/batches/index.mdx b/docs/docs/providers/batches/index.mdx index 2c64b277f8..18e5e314d2 100644 --- a/docs/docs/providers/batches/index.mdx +++ b/docs/docs/providers/batches/index.mdx @@ -1,14 +1,14 @@ --- description: "The Batches API enables efficient processing of multiple requests in a single operation, - particularly useful for processing large datasets, batch evaluation workflows, and - cost-effective inference at scale. +particularly useful for processing large datasets, batch evaluation workflows, and +cost-effective inference at scale. - The API is designed to allow use of openai client libraries for seamless integration. +The API is designed to allow use of openai client libraries for seamless integration. - This API provides the following extensions: - - idempotent batch creation +This API provides the following extensions: + - idempotent batch creation - Note: This API is currently under active development and may undergo changes." +Note: This API is currently under active development and may undergo changes." sidebar_label: Batches title: Batches --- @@ -18,14 +18,14 @@ title: Batches ## Overview The Batches API enables efficient processing of multiple requests in a single operation, - particularly useful for processing large datasets, batch evaluation workflows, and - cost-effective inference at scale. +particularly useful for processing large datasets, batch evaluation workflows, and +cost-effective inference at scale. - The API is designed to allow use of openai client libraries for seamless integration. +The API is designed to allow use of openai client libraries for seamless integration. - This API provides the following extensions: - - idempotent batch creation +This API provides the following extensions: + - idempotent batch creation - Note: This API is currently under active development and may undergo changes. +Note: This API is currently under active development and may undergo changes. This section contains documentation for all available providers for the **batches** API. diff --git a/docs/docs/providers/eval/index.mdx b/docs/docs/providers/eval/index.mdx index 94bafe15e7..45fc5ebd3d 100644 --- a/docs/docs/providers/eval/index.mdx +++ b/docs/docs/providers/eval/index.mdx @@ -1,7 +1,7 @@ --- description: "Evaluations - Llama Stack Evaluation API for running evaluations on model and agent candidates." +Llama Stack Evaluation API for running evaluations on model and agent candidates." sidebar_label: Eval title: Eval --- @@ -12,6 +12,6 @@ title: Eval Evaluations - Llama Stack Evaluation API for running evaluations on model and agent candidates. +Llama Stack Evaluation API for running evaluations on model and agent candidates. This section contains documentation for all available providers for the **eval** API. 
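The batches description above states that the API is designed to work with the OpenAI client libraries. As a rough, hedged illustration of that claim (not part of this change), submitting a batch through the `openai` SDK against a locally running stack might look like the sketch below; the base URL, port, `api_key` placeholder, and `purpose="batch"` value are assumptions taken from the default server settings later in this diff and from the OpenAI batch input format.

```python
# Hypothetical usage sketch: submit a chat-completions batch via the OpenAI SDK
# pointed at a Llama Stack server. Endpoint, port, and purpose value are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")  # port from the distro run.yaml

# One JSONL line per request, following the OpenAI batch input format.
batch_input = client.files.create(file=open("requests.jsonl", "rb"), purpose="batch")

batch = client.batches.create(
    input_file_id=batch_input.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)
print(batch.id, batch.status)
```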
diff --git a/docs/docs/providers/files/index.mdx b/docs/docs/providers/files/index.mdx index 19e338035b..c61c4f1b6d 100644 --- a/docs/docs/providers/files/index.mdx +++ b/docs/docs/providers/files/index.mdx @@ -1,7 +1,7 @@ --- description: "Files - This API is used to upload documents that can be used with other Llama Stack APIs." +This API is used to upload documents that can be used with other Llama Stack APIs." sidebar_label: Files title: Files --- @@ -12,6 +12,6 @@ title: Files Files - This API is used to upload documents that can be used with other Llama Stack APIs. +This API is used to upload documents that can be used with other Llama Stack APIs. This section contains documentation for all available providers for the **files** API. diff --git a/docs/docs/providers/inference/index.mdx b/docs/docs/providers/inference/index.mdx index c2bf69962a..322c95ee7e 100644 --- a/docs/docs/providers/inference/index.mdx +++ b/docs/docs/providers/inference/index.mdx @@ -1,11 +1,11 @@ --- description: "Inference - Llama Stack Inference API for generating completions, chat completions, and embeddings. +Llama Stack Inference API for generating completions, chat completions, and embeddings. - This API provides the raw interface to the underlying models. Two kinds of models are supported: - - LLM models: these models generate \"raw\" and \"chat\" (conversational) completions. - - Embedding models: these models generate embeddings to be used for semantic search." +This API provides the raw interface to the underlying models. Two kinds of models are supported: +- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions. +- Embedding models: these models generate embeddings to be used for semantic search." sidebar_label: Inference title: Inference --- @@ -16,10 +16,10 @@ title: Inference Inference - Llama Stack Inference API for generating completions, chat completions, and embeddings. +Llama Stack Inference API for generating completions, chat completions, and embeddings. - This API provides the raw interface to the underlying models. Two kinds of models are supported: - - LLM models: these models generate "raw" and "chat" (conversational) completions. - - Embedding models: these models generate embeddings to be used for semantic search. +This API provides the raw interface to the underlying models. Two kinds of models are supported: +- LLM models: these models generate "raw" and "chat" (conversational) completions. +- Embedding models: these models generate embeddings to be used for semantic search. This section contains documentation for all available providers for the **inference** API. diff --git a/docs/docs/providers/safety/index.mdx b/docs/docs/providers/safety/index.mdx index 4e2de4f331..038565475d 100644 --- a/docs/docs/providers/safety/index.mdx +++ b/docs/docs/providers/safety/index.mdx @@ -1,7 +1,7 @@ --- description: "Safety - OpenAI-compatible Moderations API." +OpenAI-compatible Moderations API." sidebar_label: Safety title: Safety --- @@ -12,6 +12,6 @@ title: Safety Safety - OpenAI-compatible Moderations API. +OpenAI-compatible Moderations API. This section contains documentation for all available providers for the **safety** API. diff --git a/llama_stack/distributions/oci/__init__.py b/llama_stack/distributions/oci/__init__.py new file mode 100644 index 0000000000..94d103a022 --- /dev/null +++ b/llama_stack/distributions/oci/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .oci import get_distribution_template + +__all__ = ["get_distribution_template"] diff --git a/llama_stack/distributions/oci/build.yaml b/llama_stack/distributions/oci/build.yaml new file mode 100644 index 0000000000..49da82c3fc --- /dev/null +++ b/llama_stack/distributions/oci/build.yaml @@ -0,0 +1,34 @@ +version: 2 +distribution_spec: + description: Use Oracle Cloud Infrastructure (OCI) Generative AI for running LLM + inference with scalable cloud services + providers: + inference: + - provider_type: remote::oci + vector_io: + - provider_type: inline::faiss + - provider_type: remote::chromadb + - provider_type: remote::pgvector + safety: + - provider_type: inline::llama-guard + agents: + - provider_type: inline::meta-reference + eval: + - provider_type: inline::meta-reference + datasetio: + - provider_type: remote::huggingface + - provider_type: inline::localfs + scoring: + - provider_type: inline::basic + - provider_type: inline::llm-as-judge + - provider_type: inline::braintrust + tool_runtime: + - provider_type: remote::brave-search + - provider_type: remote::tavily-search + - provider_type: remote::model-context-protocol + files: + - provider_type: inline::localfs +image_type: venv +additional_pip_packages: +- aiosqlite +- sqlalchemy[asyncio] diff --git a/llama_stack/distributions/oci/oci.py b/llama_stack/distributions/oci/oci.py new file mode 100644 index 0000000000..6367cea549 --- /dev/null +++ b/llama_stack/distributions/oci/oci.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from pathlib import Path + +from llama_stack.core.datatypes import BuildProvider, Provider, ToolGroupInput +from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings +from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig +from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig +from llama_stack.providers.remote.inference.oci.config import OCIConfig + + +def get_distribution_template(name: str = "oci") -> DistributionTemplate: + providers = { + "inference": [BuildProvider(provider_type="remote::oci")], + "vector_io": [ + BuildProvider(provider_type="inline::faiss"), + BuildProvider(provider_type="remote::chromadb"), + BuildProvider(provider_type="remote::pgvector"), + ], + "safety": [BuildProvider(provider_type="inline::llama-guard")], + "agents": [BuildProvider(provider_type="inline::meta-reference")], + "eval": [BuildProvider(provider_type="inline::meta-reference")], + "datasetio": [ + BuildProvider(provider_type="remote::huggingface"), + BuildProvider(provider_type="inline::localfs"), + ], + "scoring": [ + BuildProvider(provider_type="inline::basic"), + BuildProvider(provider_type="inline::llm-as-judge"), + BuildProvider(provider_type="inline::braintrust"), + ], + "tool_runtime": [ + BuildProvider(provider_type="remote::brave-search"), + BuildProvider(provider_type="remote::tavily-search"), + BuildProvider(provider_type="remote::model-context-protocol"), + ], + "files": [BuildProvider(provider_type="inline::localfs")], + } + + inference_provider = Provider( + provider_id="oci", + provider_type="remote::oci", + config=OCIConfig.sample_run_config(), + ) + + files_provider = Provider( + provider_id="meta-reference-files", + provider_type="inline::localfs", + config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"), + ) + vector_io_provider = Provider( + provider_id="faiss", + provider_type="inline::faiss", + config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"), + ) + + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ] + + return DistributionTemplate( + name=name, + distro_type="remote_hosted", + description="Use Oracle Cloud Infrastructure (OCI) Generative AI for running LLM inference with scalable cloud services", + container_image=None, + template_path=Path(__file__).parent / "doc_template.md", + providers=providers, + run_configs={ + "run.yaml": RunConfigSettings( + provider_overrides={ + "inference": [inference_provider], + "vector_io": [vector_io_provider], + "files": [files_provider], + }, + default_tool_groups=default_tool_groups, + ), + }, + run_config_env_vars={ + "OCI_AUTH_TYPE": ( + "instance_principal", + "OCI authentication type (instance_principal or config_file)", + ), + "OCI_USER_OCID": ( + "", + "OCI user OCID for authentication", + ), + "OCI_TENANCY_OCID": ( + "", + "OCI tenancy OCID for authentication", + ), + "OCI_FINGERPRINT": ( + "", + "OCI API key fingerprint for authentication", + ), + "OCI_PRIVATE_KEY": ( + "", + "OCI private key for authentication", + ), + "OCI_REGION": ( + "", + "OCI region (e.g., us-ashburn-1, us-chicago-1, us-phoenix-1, eu-frankfurt-1)", + ), + "OCI_COMPARTMENT_OCID": ( + "", + "OCI compartment ID for the Generative AI service", + ), + "OCI_CONFIG_FILE_PATH": ( + "~/.oci/config", + "OCI config file path (required if OCI_AUTH_TYPE is config_file)", + ), + "OCI_CLI_PROFILE": ( + "DEFAULT", + "OCI CLI profile name to use from config file", + 
), + }, + ) diff --git a/llama_stack/distributions/oci/run.yaml b/llama_stack/distributions/oci/run.yaml new file mode 100644 index 0000000000..b0c8353723 --- /dev/null +++ b/llama_stack/distributions/oci/run.yaml @@ -0,0 +1,136 @@ +version: 2 +image_name: oci +apis: +- agents +- datasetio +- eval +- files +- inference +- safety +- scoring +- tool_runtime +- vector_io +providers: + inference: + - provider_id: oci + provider_type: remote::oci + config: + oci_auth_type: ${env.OCI_AUTH_TYPE:=instance_principal} + oci_config_file_path: ${env.OCI_CONFIG_FILE_PATH:=~/.oci/config} + oci_config_profile: ${env.OCI_CLI_PROFILE:=DEFAULT} + oci_region: ${env.OCI_REGION:=us-ashburn-1} + oci_compartment_id: ${env.OCI_COMPARTMENT_OCID:=} + oci_serving_mode: ${env.OCI_SERVING_MODE:=ON_DEMAND} + oci_user_ocid: ${env.OCI_USER_OCID:=} + oci_tenancy_ocid: ${env.OCI_TENANCY_OCID:=} + oci_fingerprint: ${env.OCI_FINGERPRINT:=} + oci_private_key: ${env.OCI_PRIVATE_KEY:=} + vector_io: + - provider_id: faiss + provider_type: inline::faiss + config: + persistence: + namespace: vector_io::faiss + backend: kv_default + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: + excluded_categories: [] + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence: + agent_state: + namespace: agents + backend: kv_default + responses: + table_name: responses + backend: sql_default + max_write_queue_size: 10000 + num_writers: 4 + eval: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + kvstore: + namespace: eval + backend: kv_default + datasetio: + - provider_id: huggingface + provider_type: remote::huggingface + config: + kvstore: + namespace: datasetio::huggingface + backend: kv_default + - provider_id: localfs + provider_type: inline::localfs + config: + kvstore: + namespace: datasetio::localfs + backend: kv_default + scoring: + - provider_id: basic + provider_type: inline::basic + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + - provider_id: braintrust + provider_type: inline::braintrust + config: + openai_api_key: ${env.OPENAI_API_KEY:=} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:=} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:=} + max_results: 3 + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + files: + - provider_id: meta-reference-files + provider_type: inline::localfs + config: + storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/oci/files} + metadata_store: + table_name: files_metadata + backend: sql_default +storage: + backends: + kv_default: + type: kv_sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/oci}/kvstore.db + sql_default: + type: sql_sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/oci}/sql_store.db + stores: + metadata: + namespace: registry + backend: kv_default + inference: + table_name: inference_store + backend: sql_default + max_write_queue_size: 10000 + num_writers: 4 + conversations: + table_name: openai_conversations + backend: sql_default +registered_resources: + models: [] + shields: [] + vector_dbs: [] + datasets: [] + scoring_fns: [] + benchmarks: [] + tool_groups: + - toolgroup_id: builtin::websearch + provider_id: tavily-search +server: + port: 8321 +telemetry: + enabled: true diff --git 
a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 2e52e2d12a..a0c1df82b9 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -293,6 +293,20 @@ def available_providers() -> list[ProviderSpec]: Azure OpenAI inference provider for accessing GPT models and other Azure services. Provider documentation https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview +""", + ), + RemoteProviderSpec( + api=Api.inference, + provider_type="remote::oci", + adapter_type="oci", + pip_packages=["oci"], + module="llama_stack.providers.remote.inference.oci", + config_class="llama_stack.providers.remote.inference.oci.config.OCIConfig", + provider_data_validator="llama_stack.providers.remote.inference.oci.config.OCIProviderDataValidator", + description=""" +Oracle Cloud Infrastructure (OCI) Generative AI inference provider for accessing OCI's Generative AI Platform-as-a-Service models. +Provider documentation +https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm """, ), ] diff --git a/llama_stack/providers/remote/inference/oci/__init__.py b/llama_stack/providers/remote/inference/oci/__init__.py new file mode 100644 index 0000000000..7c0b6460f2 --- /dev/null +++ b/llama_stack/providers/remote/inference/oci/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.inference import InferenceProvider + +from .config import OCIConfig + + +async def get_adapter_impl(config: OCIConfig, _deps) -> InferenceProvider: + from .oci import OCIInferenceAdapter + + adapter = OCIInferenceAdapter(config) + return adapter diff --git a/llama_stack/providers/remote/inference/oci/config.py b/llama_stack/providers/remote/inference/oci/config.py new file mode 100644 index 0000000000..6b5e0970c9 --- /dev/null +++ b/llama_stack/providers/remote/inference/oci/config.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import os +from typing import Any + +from pydantic import BaseModel, Field + +from llama_stack.schema_utils import json_schema_type + + +class OCIProviderDataValidator(BaseModel): + oci_auth_type: str = Field( + description="OCI authentication type (must be one of: instance_principal, config_file)", + ) + oci_private_key: str | None = Field( + description="OCI private key for authentication", + ) + oci_config_file_path: str | None = Field( + default=None, + description="OCI config file path (required if oci_auth_type is config_file)", + ) + oci_config_profile: str = Field( + description="OCI config profile (required if oci_auth_type is config_file)", + ) + oci_region: str | None = Field( + default=None, + description="OCI region (e.g., us-ashburn-1)", + ) + oci_compartment_id: str | None = Field( + default=None, + description="OCI compartment ID for the Generative AI service", + ) + oci_user_ocid: str | None = Field( + default=None, + description="OCI user OCID for authentication", + ) + oci_tenancy_ocid: str | None = Field( + default=None, + description="OCI tenancy OCID for authentication", + ) + oci_fingerprint: str | None = Field( + default=None, + description="OCI API key fingerprint for authentication", + ) + oci_serving_mode: str = Field( + default="ON_DEMAND", + description="OCI serving mode (must be one of: ON_DEMAND, DEDICATED)", + ) + + +@json_schema_type +class OCIConfig(BaseModel): + oci_auth_type: str = Field( + description="OCI authentication type (must be one of: instance_principal, config_file)", + default_factory=lambda: os.getenv("OCI_AUTH_TYPE", "instance_principal"), + ) + oci_config_file_path: str = Field( + default_factory=lambda: os.getenv("OCI_CONFIG_FILE_PATH", "~/.oci/config"), + description="OCI config file path (required if oci_auth_type is config_file)", + ) + oci_config_profile: str = Field( + default_factory=lambda: os.getenv("OCI_CLI_PROFILE", "DEFAULT"), + description="OCI config profile (required if oci_auth_type is config_file)", + ) + oci_region: str | None = Field( + default_factory=lambda: os.getenv("OCI_REGION"), + description="OCI region (e.g., us-ashburn-1)", + ) + oci_compartment_id: str | None = Field( + default_factory=lambda: os.getenv("OCI_COMPARTMENT_OCID"), + description="OCI compartment ID for the Generative AI service", + ) + oci_user_ocid: str | None = Field( + default_factory=lambda: os.getenv("OCI_USER_OCID"), + description="OCI user OCID for authentication", + ) + oci_tenancy_ocid: str | None = Field( + default_factory=lambda: os.getenv("OCI_TENANCY_OCID"), + description="OCI tenancy OCID for authentication", + ) + oci_fingerprint: str | None = Field( + default_factory=lambda: os.getenv("OCI_FINGERPRINT"), + description="OCI API key fingerprint for authentication", + ) + oci_private_key: str | None = Field( + description="OCI private key for authentication", + ) + oci_serving_mode: str = Field( + default_factory=lambda: os.getenv("OCI_SERVING_MODE", "ON_DEMAND"), + description="OCI serving mode (must be one of: ON_DEMAND, DEDICATED)", + ) + + @classmethod + def sample_run_config( + cls, + oci_auth_type: str = "${env.OCI_AUTH_TYPE:=instance_principal}", + oci_config_file_path: str = "${env.OCI_CONFIG_FILE_PATH:=~/.oci/config}", + oci_config_profile: str = "${env.OCI_CLI_PROFILE:=DEFAULT}", + oci_region: str = "${env.OCI_REGION:=us-ashburn-1}", + oci_compartment_id: str = "${env.OCI_COMPARTMENT_OCID:=}", + oci_serving_mode: str = "${env.OCI_SERVING_MODE:=ON_DEMAND}", + oci_user_ocid: str = "${env.OCI_USER_OCID:=}", + oci_tenancy_ocid: 
str = "${env.OCI_TENANCY_OCID:=}", + oci_fingerprint: str = "${env.OCI_FINGERPRINT:=}", + oci_private_key: str = "${env.OCI_PRIVATE_KEY:=}", + **kwargs, + ) -> dict[str, Any]: + return { + "oci_auth_type": oci_auth_type, + "oci_config_file_path": oci_config_file_path, + "oci_config_profile": oci_config_profile, + "oci_region": oci_region, + "oci_compartment_id": oci_compartment_id, + "oci_serving_mode": oci_serving_mode, + "oci_user_ocid": oci_user_ocid, + "oci_tenancy_ocid": oci_tenancy_ocid, + "oci_fingerprint": oci_fingerprint, + "oci_private_key": oci_private_key, + } diff --git a/llama_stack/providers/remote/inference/oci/models.py b/llama_stack/providers/remote/inference/oci/models.py new file mode 100644 index 0000000000..affb1e4ee8 --- /dev/null +++ b/llama_stack/providers/remote/inference/oci/models.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import oci +from oci.generative_ai.generative_ai_client import GenerativeAiClient +from oci.generative_ai.models import ModelCollection + +from llama_stack.apis.models import ModelType +from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, ProviderModelEntry + + +def build_oci_model_entries( + compartment_id: str, + oci_config: dict | None = None, + oci_signer: oci.auth.signers.InstancePrincipalsSecurityTokenSigner | None = None, +) -> list[ProviderModelEntry]: + """ + Build OCI model entries by fetching available models from OCI Generative AI service. + + Args: + compartment_id: OCI compartment ID where models are located + config: OCI config dict (if None, will use config file) + config_profile: OCI config profile to use (defaults to "CHICAGO" or OCI_CLI_PROFILE env var) + + Returns: + List of ProviderModelEntry objects mapping display_name to model.id + """ + if oci_signer is None: + client = GenerativeAiClient(config=oci_config) + else: + client = GenerativeAiClient(config=oci_config, signer=oci_signer) + + print("Here 1") + models: ModelCollection = client.list_models(compartment_id=compartment_id, capability=["CHAT"]).data + + model_entries = [] + seen_models = set() + for model in models.items: + if model.time_deprecated or model.time_on_demand_retired: + continue + + if "CHAT" not in model.capabilities or "FINE_TUNE" in model.capabilities: + continue + + if model.display_name in seen_models: + continue + + seen_models.add(model.display_name) + + entry = ProviderModelEntry( + provider_model_id=model.id, + aliases=[model.display_name], + model_type=ModelType.llm, + metadata={ + "display_name": model.display_name, + "capabilities": model.capabilities, + "oci_model_id": model.id, + }, + ) + + seen_models.add(model.display_name) + + model_entries.append(entry) + + return model_entries + + +class OCIModelRegistryHelper(ModelRegistryHelper): + """ + OCI-specific model registry helper that dynamically fetches models from OCI. 
+ """ + + def __init__( + self, + compartment_id: str, + oci_config: dict | None = None, + oci_signer: oci.auth.signers.InstancePrincipalsSecurityTokenSigner | None = None, + allowed_models: list[str] | None = None, + ): + model_entries = build_oci_model_entries(compartment_id, oci_config, oci_signer) + + super().__init__(model_entries=model_entries, allowed_models=allowed_models) + + self.compartment_id = compartment_id + self.oci_config = oci_config + self.oci_signer = oci_signer + + async def should_refresh_models(self) -> bool: + return True + + async def check_model_availability(self, alias: str) -> bool: + client = GenerativeAiClient(config=self.oci_config, signer=self.oci_signer) + model_id = self.get_provider_model_id(alias) + response = client.get_model(model_id) + return response.data is not None + + +# For backward compatibility, create an empty MODEL_ENTRIES list +# The actual models will be built dynamically by OCIModelRegistryHelper +MODEL_ENTRIES: list[ProviderModelEntry] = [] diff --git a/llama_stack/providers/remote/inference/oci/oci.py b/llama_stack/providers/remote/inference/oci/oci.py new file mode 100644 index 0000000000..0335a8e837 --- /dev/null +++ b/llama_stack/providers/remote/inference/oci/oci.py @@ -0,0 +1,447 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +import time +from collections.abc import AsyncGenerator, AsyncIterator + +from llama_stack.log import get_logger + +logger = get_logger(__name__) + +import oci +from oci.generative_ai_inference import GenerativeAiInferenceClient +from oci.generative_ai_inference.models import ( + ChatDetails, + ChatResult, + DedicatedServingMode, + GenericChatRequest, + OnDemandServingMode, + SystemMessage, + TextContent, + UserMessage, +) + +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseStreamChunk, + Inference, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + ToolChoice, + ToolConfig, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.apis.inference.inference import ( + OpenAIAssistantMessageParam, + OpenAIChatCompletion, + OpenAIChatCompletionChunk, + OpenAIChatCompletionRequestWithExtraBody, + OpenAIChatCompletionUsage, + OpenAIChoice, + OpenAIChoiceDelta, + OpenAIChunkChoice, + OpenAICompletion, + OpenAICompletionRequestWithExtraBody, + OpenAIEmbeddingsRequestWithExtraBody, + OpenAIEmbeddingsResponse, +) +from llama_stack.log import get_logger +from llama_stack.providers.utils.inference.openai_compat import ( + OpenAICompatCompletionChoice, + OpenAICompatCompletionResponse, + prepare_openai_completion_params, + process_chat_completion_response, + process_chat_completion_stream_response, +) + +from .config import OCIConfig +from .models import OCIModelRegistryHelper + +logger = get_logger(name=__name__, category="inference::oci") + +OCI_AUTH_TYPE_INSTANCE_PRINCIPAL = "instance_principal" +OCI_AUTH_TYPE_CONFIG_FILE = "config_file" +VALID_OCI_AUTH_TYPES = [OCI_AUTH_TYPE_INSTANCE_PRINCIPAL, OCI_AUTH_TYPE_CONFIG_FILE] + +OCI_SERVING_MODE_ON_DEMAND = "ON_DEMAND" +OCI_SERVING_MODE_DEDICATED = "DEDICATED" +VALID_OCI_SERVING_MODES = [OCI_SERVING_MODE_ON_DEMAND, OCI_SERVING_MODE_DEDICATED] + + +class OCIInferenceAdapter(Inference, OCIModelRegistryHelper): + def __init__(self, config: OCIConfig) -> None: + self.config = config + self._client: 
diff --git a/llama_stack/providers/remote/inference/oci/oci.py b/llama_stack/providers/remote/inference/oci/oci.py
new file mode 100644
index 0000000000..0335a8e837
--- /dev/null
+++ b/llama_stack/providers/remote/inference/oci/oci.py
@@ -0,0 +1,447 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import json
+import time
+from collections.abc import AsyncGenerator, AsyncIterator
+
+import oci
+from oci.generative_ai_inference import GenerativeAiInferenceClient
+from oci.generative_ai_inference.models import (
+    ChatDetails,
+    ChatResult,
+    DedicatedServingMode,
+    GenericChatRequest,
+    OnDemandServingMode,
+    SystemMessage,
+    TextContent,
+    UserMessage,
+)
+
+from llama_stack.apis.inference import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseStreamChunk,
+    Inference,
+    LogProbConfig,
+    Message,
+    ResponseFormat,
+    SamplingParams,
+    ToolChoice,
+    ToolConfig,
+    ToolDefinition,
+    ToolPromptFormat,
+)
+from llama_stack.apis.inference.inference import (
+    OpenAIAssistantMessageParam,
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAIChatCompletionRequestWithExtraBody,
+    OpenAIChatCompletionUsage,
+    OpenAIChoice,
+    OpenAIChoiceDelta,
+    OpenAIChunkChoice,
+    OpenAICompletion,
+    OpenAICompletionRequestWithExtraBody,
+    OpenAIEmbeddingsRequestWithExtraBody,
+    OpenAIEmbeddingsResponse,
+)
+from llama_stack.log import get_logger
+from llama_stack.providers.utils.inference.openai_compat import (
+    OpenAICompatCompletionChoice,
+    OpenAICompatCompletionResponse,
+    prepare_openai_completion_params,
+    process_chat_completion_response,
+    process_chat_completion_stream_response,
+)
+
+from .config import OCIConfig
+from .models import OCIModelRegistryHelper
+
+logger = get_logger(name=__name__, category="inference::oci")
+
+OCI_AUTH_TYPE_INSTANCE_PRINCIPAL = "instance_principal"
+OCI_AUTH_TYPE_CONFIG_FILE = "config_file"
+VALID_OCI_AUTH_TYPES = [OCI_AUTH_TYPE_INSTANCE_PRINCIPAL, OCI_AUTH_TYPE_CONFIG_FILE]
+
+OCI_SERVING_MODE_ON_DEMAND = "ON_DEMAND"
+OCI_SERVING_MODE_DEDICATED = "DEDICATED"
+VALID_OCI_SERVING_MODES = [OCI_SERVING_MODE_ON_DEMAND, OCI_SERVING_MODE_DEDICATED]
+
+
+class OCIInferenceAdapter(Inference, OCIModelRegistryHelper):
+    def __init__(self, config: OCIConfig) -> None:
+        self.config = config
+        self._client: GenerativeAiInferenceClient | None = None
+
+        if self.config.oci_auth_type not in VALID_OCI_AUTH_TYPES:
+            raise ValueError(
+                f"Invalid OCI authentication type: {self.config.oci_auth_type}. "
+                f"Valid types are one of: {VALID_OCI_AUTH_TYPES}"
+            )
+
+        if not self.config.oci_compartment_id:
+            raise ValueError(
+                "OCI_COMPARTMENT_OCID is a required parameter. Set it in the provider config or via the "
+                "OCI_COMPARTMENT_OCID environment variable."
+            )
+
+        if self.config.oci_serving_mode not in VALID_OCI_SERVING_MODES:
+            raise ValueError(
+                f"Invalid OCI serving mode: {self.config.oci_serving_mode}. "
+                f"Valid modes are one of: {VALID_OCI_SERVING_MODES}"
+            )
+
+        # Initialize with OCI-specific model registry helper after validation
+        OCIModelRegistryHelper.__init__(
+            self,
+            compartment_id=config.oci_compartment_id or "",
+            oci_config=self._get_oci_config(),
+            oci_signer=self._get_oci_signer(),
+        )
+
+    @property
+    def client(self) -> GenerativeAiInferenceClient:
+        if self._client is None:
+            self._client = self._get_client()
+        return self._client
+
+    def _get_oci_config(self) -> dict:
+        if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
+            config = {"region": self.config.oci_region}
+        elif self.config.oci_auth_type == OCI_AUTH_TYPE_CONFIG_FILE:
+            config = oci.config.from_file(self.config.oci_config_file_path, self.config.oci_config_profile)
+            if not config.get("region"):
+                raise ValueError(
+                    "Region not specified in config. Please specify in config or with OCI_REGION env variable."
+                )
+
+        return config
+
+    def _get_oci_signer(self) -> oci.auth.signers.InstancePrincipalsSecurityTokenSigner | None:
+        if self.config.oci_auth_type == OCI_AUTH_TYPE_INSTANCE_PRINCIPAL:
+            return oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
+        return None
+
+    def _get_client(self) -> GenerativeAiInferenceClient:
+        if self._client is None:
+            if self._get_oci_signer() is None:
+                return GenerativeAiInferenceClient(config=self._get_oci_config())
+            else:
+                return GenerativeAiInferenceClient(
+                    config=self._get_oci_config(),
+                    signer=self._get_oci_signer(),
+                )
+        return self._client
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def _build_chat_details(self, request: ChatCompletionRequest) -> ChatDetails:
+        messages = []
+        system_messages = []
+        user_messages = []
+
+        for msg in request.messages:
+            if msg.role == "system":
+                system_messages.append(SystemMessage(name="System", content=[TextContent(text=msg.content)]))
+            elif msg.role == "user":
+                user_messages.append(UserMessage(name="User", content=[TextContent(text=msg.content)]))
+
+        messages.extend(system_messages)
+        messages.extend(user_messages)
+
+        # Create chat request
+        sampling_params: SamplingParams | None = request.sampling_params if request.sampling_params else None
+        chat_request = GenericChatRequest(
+            api_format="GENERIC",
+            messages=messages,
+            is_stream=request.stream,
+            num_generations=1,
+            seed=42,
+            is_echo=False,
+            top_k=-1,
+            top_p=0.95,
+            temperature=0.7,
+            frequency_penalty=0,
+            presence_penalty=sampling_params.repetition_penalty if sampling_params else 0,
+            max_tokens=sampling_params.max_tokens if sampling_params else 512,
+            stop=sampling_params.stop if sampling_params else None,
+        )
+
+        model_id = self.get_provider_model_id(request.model)
+        if self.config.oci_serving_mode == OCI_SERVING_MODE_ON_DEMAND:
+            serving_mode = OnDemandServingMode(serving_type="ON_DEMAND", model_id=model_id)
+        elif self.config.oci_serving_mode == OCI_SERVING_MODE_DEDICATED:
+            serving_mode = DedicatedServingMode(serving_type="DEDICATED", model_id=model_id)
+
+        chat_details = ChatDetails(
+            compartment_id=self.config.oci_compartment_id, serving_mode=serving_mode, chat_request=chat_request
+        )
+        return chat_details
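`_build_chat_details` above assembles the same request objects one would pass to the OCI SDK directly. For reference, a minimal hedged sketch of the equivalent direct call is shown below; the compartment OCID and model OCID are placeholders, and config-file auth is assumed.

```python
# Hypothetical sketch: the raw OCI request that _build_chat_details assembles, issued directly.
import oci
from oci.generative_ai_inference import GenerativeAiInferenceClient
from oci.generative_ai_inference.models import (
    ChatDetails,
    GenericChatRequest,
    OnDemandServingMode,
    TextContent,
    UserMessage,
)

client = GenerativeAiInferenceClient(config=oci.config.from_file())
details = ChatDetails(
    compartment_id="ocid1.compartment.oc1..example",  # placeholder
    serving_mode=OnDemandServingMode(
        serving_type="ON_DEMAND",
        model_id="ocid1.generativeaimodel.oc1..example",  # placeholder
    ),
    chat_request=GenericChatRequest(
        api_format="GENERIC",
        messages=[UserMessage(name="User", content=[TextContent(text="Say hello.")])],
        is_stream=False,
        max_tokens=64,
        temperature=0.7,
    ),
)
result = client.chat(details)
print(result.data.chat_response.choices[0].message.content[0].text)
```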
+
+    async def chat_completion(
+        self,
+        model_id: str,
+        messages: list[Message],
+        sampling_params: SamplingParams | None = None,
+        tools: list[ToolDefinition] | None = None,
+        tool_choice: ToolChoice | None = ToolChoice.auto,
+        tool_prompt_format: ToolPromptFormat | None = None,
+        response_format: ResponseFormat | None = None,
+        stream: bool | None = False,
+        logprobs: LogProbConfig | None = None,
+        tool_config: ToolConfig | None = None,
+    ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
+        request = ChatCompletionRequest(
+            model=model_id,
+            messages=messages,
+            sampling_params=sampling_params,
+            tools=tools or [],
+            response_format=response_format,
+            stream=stream,
+            logprobs=logprobs,
+            tool_config=tool_config,
+        )
+        chat_details = await self._build_chat_details(request)
+        if stream:
+            return self._stream_chat_completion(request, chat_details)
+        else:
+            return await self._nonstream_chat_completion(request, chat_details)
+
+    async def _stream_chat_completion(
+        self, request: ChatCompletionRequest, details: ChatDetails
+    ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
+        """
+        Perform streaming chat completion using OCI Generative AI.
+        """
+        response = self._get_client().chat(details)
+        stream = response.data
+
+        async def _generate_and_convert_to_openai_compat():
+            for chunk in stream.events():
+                # Example SSE payloads from OCI:
+                # {'index': 0, 'message': {'role': 'ASSISTANT', 'content': [{'type': 'TEXT', 'text': ' knowledge'}]}, 'pad': 'aaaaaaaa'}
+                # {'message': {'role': 'ASSISTANT'}, 'finishReason': 'stop', 'pad': 'aaaaaaaaaaaa'}
+                data = json.loads(chunk.data)
+                finish_reason = data.get("finishReason", None)
+                message_content = data.get("message", {}).get("content", [])
+                text = ""
+                if message_content:
+                    text = message_content[0].get("text", "")
+                choice = OpenAICompatCompletionChoice(finish_reason=finish_reason, text=text)
+                yield OpenAICompatCompletionResponse(choices=[choice])
+
+        stream = _generate_and_convert_to_openai_compat()
+        async for chunk in process_chat_completion_stream_response(stream, request):
+            yield chunk
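The streaming helper above converts OCI server-sent events into OpenAI-compatible chunks; the two payload shapes it must handle are shown in its inline comments. A small self-contained sketch of that per-event parsing, run against sample payloads, is shown below (the payload strings are illustrative, derived from those comments).

```python
# Self-contained sketch of the per-event parsing done in _generate_and_convert_to_openai_compat.
import json

samples = [
    '{"index": 0, "message": {"role": "ASSISTANT", "content": [{"type": "TEXT", "text": " knowledge"}]}}',
    '{"message": {"role": "ASSISTANT"}, "finishReason": "stop"}',
]

for raw in samples:
    data = json.loads(raw)
    finish_reason = data.get("finishReason", None)
    content = data.get("message", {}).get("content", [])
    text = content[0].get("text", "") if content else ""
    print({"text": text, "finish_reason": finish_reason})
# -> {'text': ' knowledge', 'finish_reason': None}
# -> {'text': '', 'finish_reason': 'stop'}
```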
+ """ + response = self._get_client().chat(details) + finish_reason = None + message_content = "" + if response.data.choices: + finish_reason = response.data.choices[0].finish_reason + message_content = ( + response.data.choices[0].message.content[0].text if response.data.choices[0].message.content else "" + ) + choice = OpenAICompatCompletionChoice( + finish_reason=finish_reason, + text=message_content, + ) + return process_chat_completion_response(OpenAICompatCompletionResponse(choices=[choice]), request) + + async def _build_openai_chat_details(self, params: dict) -> ChatDetails: + messages = params.get("messages", []) + system_messages = [] + user_messages = [] + structured_messages = [] + + for msg in messages: + if msg.get("role", "") == "system": + system_messages.append(SystemMessage(content=[TextContent(text=msg.get("content", ""))])) + else: + user_messages.append(UserMessage(name="User", content=[TextContent(text=msg.get("content", ""))])) + + structured_messages.extend(system_messages) + structured_messages.extend(user_messages) + + # Create OCI chat request + chat_request = GenericChatRequest( + api_format="GENERIC", + messages=structured_messages, + reasoning_effort=params.get("reasoning_effort"), + verbosity=params.get("verbosity"), + metadata=params.get("metadata"), + is_stream=params.get("stream", False), + stream_options=params.get("stream_options"), + num_generations=params.get("n"), + seed=params.get("seed"), + is_echo=params.get("echo", False), + top_k=params.get("top_k"), + top_p=params.get("top_p"), + temperature=params.get("temperature"), + frequency_penalty=params.get("frequency_penalty"), + presence_penalty=params.get("presence_penalty"), + max_tokens=params.get("max_tokens"), + max_completion_tokens=params.get("max_completion_tokens"), + logit_bias=params.get("logit_bias"), + # log_probs=params.get("log_probs", 0), + # tool_choice=params.get("tool_choice", {}), # Unsupported + # tools=params.get("tools", {}), # Unsupported + # web_search_options=params.get("web_search_options", {}), # Unsupported + # stop=params.get("stop", []), + ) + + model_id = self.get_provider_model_id(params.get("model", "")) + if self.config.oci_serving_mode == OCI_SERVING_MODE_ON_DEMAND: + serving_mode = OnDemandServingMode( + serving_type="ON_DEMAND", + model_id=model_id, + ) + elif self.config.oci_serving_mode == OCI_SERVING_MODE_DEDICATED: + serving_mode = DedicatedServingMode(serving_type="DEDICATED", model_id=model_id) + + chat_details = ChatDetails( + compartment_id=self.config.oci_compartment_id, serving_mode=serving_mode, chat_request=chat_request + ) + return chat_details + + async def openai_chat_completion( + self, + params: OpenAIChatCompletionRequestWithExtraBody, + ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: + if self.model_store is None: + raise ValueError("Model store is not initialized") + model_obj = await self.model_store.get_model(params.model) + request_params = await prepare_openai_completion_params( + model=model_obj.provider_resource_id, + messages=params.messages, + frequency_penalty=params.frequency_penalty, + function_call=params.function_call, + functions=params.functions, + logit_bias=params.logit_bias, + logprobs=params.logprobs, + max_completion_tokens=params.max_completion_tokens, + max_tokens=params.max_tokens, + n=params.n, + parallel_tool_calls=params.parallel_tool_calls, + presence_penalty=params.presence_penalty, + response_format=params.response_format, + seed=params.seed, + stop=params.stop, + stream=params.stream, + 
+
+    async def openai_chat_completion(
+        self,
+        params: OpenAIChatCompletionRequestWithExtraBody,
+    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
+        if self.model_store is None:
+            raise ValueError("Model store is not initialized")
+        model_obj = await self.model_store.get_model(params.model)
+        request_params = await prepare_openai_completion_params(
+            model=model_obj.provider_resource_id,
+            messages=params.messages,
+            frequency_penalty=params.frequency_penalty,
+            function_call=params.function_call,
+            functions=params.functions,
+            logit_bias=params.logit_bias,
+            logprobs=params.logprobs,
+            max_completion_tokens=params.max_completion_tokens,
+            max_tokens=params.max_tokens,
+            n=params.n,
+            parallel_tool_calls=params.parallel_tool_calls,
+            presence_penalty=params.presence_penalty,
+            response_format=params.response_format,
+            seed=params.seed,
+            stop=params.stop,
+            stream=params.stream,
+            stream_options=params.stream_options,
+            temperature=params.temperature,
+            tool_choice=params.tool_choice,
+            tools=params.tools,
+            top_logprobs=params.top_logprobs,
+            top_p=params.top_p,
+            user=params.user,
+        )
+        chat_details = await self._build_openai_chat_details(request_params)
+        if request_params.get("stream", False):
+            return self._stream_openai_chat_completion(chat_details)
+        return await self._nonstream_openai_chat_completion(chat_details)
+
+    async def _nonstream_openai_chat_completion(
+        self,
+        chat_details: ChatDetails,
+    ) -> OpenAIChatCompletion:
+        """Non-streaming OpenAI chat completion using OCI."""
+        response: ChatResult = self._get_client().chat(chat_details)
+
+        choice = OpenAIChoice(
+            message=OpenAIAssistantMessageParam(
+                role="assistant",
+                content=response.data.chat_response.choices[0].message.content[0].text,
+            ),
+            finish_reason=response.data.chat_response.choices[0].finish_reason,
+            index=response.data.chat_response.choices[0].index,
+        )
+        return OpenAIChatCompletion(
+            id=str(response.data.chat_response.choices[0].index),
+            choices=[choice],
+            object="chat.completion",
+            created=int(response.data.chat_response.time_created.timestamp()),
+            model=response.data.model_id,
+            usage=OpenAIChatCompletionUsage(
+                prompt_tokens=response.data.chat_response.usage.prompt_tokens,
+                completion_tokens=response.data.chat_response.usage.completion_tokens,
+                total_tokens=response.data.chat_response.usage.total_tokens,
+            ),
+        )
+
+    async def _stream_openai_chat_completion(
+        self,
+        chat_details: ChatDetails,
+    ) -> AsyncIterator[OpenAIChatCompletionChunk]:
+        """Streaming OpenAI chat completion using OCI."""
+        response = self._get_client().chat(chat_details)
+        stream = response.data
+
+        i = 0
+        for chunk in stream.events():
+            i += 1
+
+            data = json.loads(chunk.data)
+            finish_reason = data.get("finishReason", "")
+            message_content = data.get("message", {}).get("content", [])
+            usage = data.get("usage", None)
+
+            # Extract text content from the message content array
+            text_content = ""
+            if message_content:
+                text_content = message_content[0].get("text", "")
+
+            # Get model_id from the response data
+            model_id = getattr(response.data, "model_id", None) or chat_details.serving_mode.model_id
+
+            if usage:
+                final_usage = OpenAIChatCompletionUsage(
+                    prompt_tokens=usage.get("promptTokens", 0),
+                    completion_tokens=usage.get("completionTokens", 0),
+                    total_tokens=usage.get("totalTokens", 0),
+                )
+            else:
+                final_usage = None
+            yield OpenAIChatCompletionChunk(
+                id=f"chunk-{i}",
+                choices=[
+                    OpenAIChunkChoice(
+                        delta=OpenAIChoiceDelta(content=text_content),
+                        finish_reason=finish_reason,
+                        index=int(data.get("index", 0)),
+                    )
+                ],
+                object="chat.completion.chunk",
+                created=int(time.time()),
+                model=model_id,
+                usage=final_usage,
+            )
+
+    async def openai_completion(
+        self,
+        params: OpenAICompletionRequestWithExtraBody,
+    ) -> OpenAICompletion:
+        raise NotImplementedError("OpenAI completion is not supported for OCI")
+
+    async def openai_embeddings(
+        self,
+        params: OpenAIEmbeddingsRequestWithExtraBody,
+    ) -> OpenAIEmbeddingsResponse:
+        raise NotImplementedError("OpenAI embeddings is not supported for OCI")
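Because the adapter implements `openai_chat_completion`, a running `oci` distribution should be reachable with the standard OpenAI client. The hedged end-to-end sketch below assumes the default server address from run.yaml; the model alias is a placeholder and must be one of the display names registered from your compartment.

```python
# Hypothetical end-to-end sketch against a running `oci` distribution.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")  # port from run.yaml

MODEL = "<oci-model-display-name>"  # placeholder: any alias registered by OCIModelRegistryHelper

# Non-streaming
resp = client.chat.completions.create(
    model=MODEL,
    messages=[{"role": "user", "content": "Name one OCI region."}],
)
print(resp.choices[0].message.content)

# Streaming
stream = client.chat.completions.create(
    model=MODEL,
    messages=[{"role": "user", "content": "Hi"}],
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")
```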
{"error":{"message":"Unknown request URL: GET /openai/v1/completions. Please check the URL for typos, # or see the docs at https://console.groq.com/docs/","type":"invalid_request_error","code":"unknown_url"}} "remote::groq", + "remote::oci", "remote::gemini", # https://generativelanguage.googleapis.com/v1beta/openai/completions -> 404 "remote::anthropic", # at least claude-3-{5,7}-{haiku,sonnet}-* / claude-{sonnet,opus}-4-* are not supported "remote::azure", # {'error': {'code': 'OperationNotSupported', 'message': 'The completion operation @@ -124,6 +125,7 @@ def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, mode "remote::bedrock", "remote::databricks", "remote::cerebras", + "remote::oci", "remote::runpod", ): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.")