Commit a4bad6c
feat: Add Google Vertex AI inference provider support (llamastack#2841)

# What does this PR do?

- Add a new Vertex AI remote inference provider with litellm integration
- Support Gemini models through the Google Cloud Vertex AI platform
- Use Google Cloud Application Default Credentials (ADC) for authentication
- Add Vertex AI models: gemini-2.5-flash, gemini-2.5-pro, gemini-2.0-flash
- Update the provider registry to include the vertexai provider
- Update the starter template to support Vertex AI configuration
- Add documentation and a sample configuration

Relates to llamastack#2747

Signed-off-by: Eran Cohen <[email protected]>
Co-authored-by: Francisco Arceo <[email protected]>
1 parent 78a59a4 commit a4bad6c

File tree: 14 files changed, +227 −0 lines

docs/source/providers/inference/index.md

Lines changed: 1 addition & 0 deletions

````diff
@@ -29,6 +29,7 @@ remote_runpod
 remote_sambanova
 remote_tgi
 remote_together
+remote_vertexai
 remote_vllm
 remote_watsonx
 ```
````
Lines changed: 40 additions & 0 deletions (new file)

````markdown
# remote::vertexai

## Description

Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages:

• Enterprise-grade security: Uses Google Cloud's security controls and IAM
• Better integration: Seamless integration with other Google Cloud services
• Advanced features: Access to additional Vertex AI features like model tuning and monitoring
• Authentication: Uses Google Cloud Application Default Credentials (ADC) instead of API keys

Configuration:
- Set VERTEX_AI_PROJECT environment variable (required)
- Set VERTEX_AI_LOCATION environment variable (optional, defaults to us-central1)
- Use Google Cloud Application Default Credentials or a service account key

Authentication Setup:
Option 1 (Recommended): gcloud auth application-default login
Option 2: Set GOOGLE_APPLICATION_CREDENTIALS to the service account key path

Available Models:
- vertex_ai/gemini-2.0-flash
- vertex_ai/gemini-2.5-flash
- vertex_ai/gemini-2.5-pro

## Configuration

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `project` | `str` | No | | Google Cloud project ID for Vertex AI |
| `location` | `str` | No | us-central1 | Google Cloud location for Vertex AI |

## Sample Configuration

```yaml
project: ${env.VERTEX_AI_PROJECT:=}
location: ${env.VERTEX_AI_LOCATION:=us-central1}
```
````
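Once ADC is configured, the documented environment variables map directly onto a litellm-style request for one of the models above. A minimal stdlib-only sketch of that mapping (the helper name `vertexai_request_kwargs` and the project ID are invented for illustration; the real provider wires this up through litellm internally):

```python
import os

def vertexai_request_kwargs(model: str) -> dict:
    """Build litellm-style kwargs for a Vertex AI call from the
    documented environment variables (illustrative sketch only)."""
    project = os.environ.get("VERTEX_AI_PROJECT")
    if not project:
        raise ValueError("VERTEX_AI_PROJECT is required")
    return {
        # Models are addressed with the vertex_ai/ prefix, per the list above.
        "model": f"vertex_ai/{model}",
        "vertex_project": project,
        "vertex_location": os.environ.get("VERTEX_AI_LOCATION", "us-central1"),
    }

os.environ["VERTEX_AI_PROJECT"] = "my-demo-project"  # hypothetical project ID
print(vertexai_request_kwargs("gemini-2.5-flash"))
```

Note that no API key appears anywhere: with ADC, credentials are resolved from the environment (gcloud login or `GOOGLE_APPLICATION_CREDENTIALS`) rather than passed per request.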

llama_stack/distributions/ci-tests/build.yaml

Lines changed: 1 addition & 0 deletions

```diff
@@ -14,6 +14,7 @@ distribution_spec:
   - provider_type: remote::openai
   - provider_type: remote::anthropic
   - provider_type: remote::gemini
+  - provider_type: remote::vertexai
   - provider_type: remote::groq
   - provider_type: remote::sambanova
   - provider_type: inline::sentence-transformers
```

llama_stack/distributions/ci-tests/run.yaml

Lines changed: 5 additions & 0 deletions

```diff
@@ -65,6 +65,11 @@ providers:
     provider_type: remote::gemini
     config:
       api_key: ${env.GEMINI_API_KEY:=}
+  - provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
+    provider_type: remote::vertexai
+    config:
+      project: ${env.VERTEX_AI_PROJECT:=}
+      location: ${env.VERTEX_AI_LOCATION:=us-central1}
   - provider_id: groq
     provider_type: remote::groq
     config:
```
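The `provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}` line relies on two substitution forms: `:=` falls back to a default when the variable is unset, while `:+` yields its value only when the variable is set, so the provider is only registered when a project is actually configured. A rough stdlib sketch of those two forms (illustrative only; Llama Stack's actual resolver may differ on edge cases):

```python
import re

def expand(template: str, env: dict[str, str]) -> str:
    # ${env.VAR:=default} -> value of VAR, or the default when unset;
    # ${env.VAR:+value}   -> the literal value only when VAR is set, else "".
    def repl(m: re.Match) -> str:
        var, op, arg = m.group(1), m.group(2), m.group(3)
        val = env.get(var)
        if op == ":=":
            return val if val is not None else arg
        return arg if val is not None else ""
    return re.sub(r"\$\{env\.(\w+)(:[=+])([^}]*)\}", repl, template)

print(expand("${env.VERTEX_AI_PROJECT:+vertexai}", {"VERTEX_AI_PROJECT": "p"}))  # vertexai
print(expand("${env.VERTEX_AI_PROJECT:+vertexai}", {}))                          # "" (provider disabled)
print(expand("${env.VERTEX_AI_LOCATION:=us-central1}", {}))                      # us-central1
```

This mirrors shell parameter expansion (`${var:=default}` / `${var:+alt}`), which is presumably where the syntax comes from.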

llama_stack/distributions/starter/build.yaml

Lines changed: 1 addition & 0 deletions

```diff
@@ -14,6 +14,7 @@ distribution_spec:
   - provider_type: remote::openai
   - provider_type: remote::anthropic
   - provider_type: remote::gemini
+  - provider_type: remote::vertexai
   - provider_type: remote::groq
   - provider_type: remote::sambanova
   - provider_type: inline::sentence-transformers
```

llama_stack/distributions/starter/run.yaml

Lines changed: 5 additions & 0 deletions

```diff
@@ -65,6 +65,11 @@ providers:
     provider_type: remote::gemini
     config:
       api_key: ${env.GEMINI_API_KEY:=}
+  - provider_id: ${env.VERTEX_AI_PROJECT:+vertexai}
+    provider_type: remote::vertexai
+    config:
+      project: ${env.VERTEX_AI_PROJECT:=}
+      location: ${env.VERTEX_AI_LOCATION:=us-central1}
   - provider_id: groq
     provider_type: remote::groq
     config:
```

llama_stack/distributions/starter/starter.py

Lines changed: 10 additions & 0 deletions

```diff
@@ -56,6 +56,7 @@ def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
     "fireworks",
     "together",
     "gemini",
+    "vertexai",
     "groq",
     "sambanova",
     "anthropic",
@@ -71,6 +72,7 @@ def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
     "tgi": "${env.TGI_URL:+tgi}",
     "cerebras": "${env.CEREBRAS_API_KEY:+cerebras}",
     "nvidia": "${env.NVIDIA_API_KEY:+nvidia}",
+    "vertexai": "${env.VERTEX_AI_PROJECT:+vertexai}",
 }


@@ -246,6 +248,14 @@ def get_distribution_template() -> DistributionTemplate:
             "",
             "Gemini API Key",
         ),
+        "VERTEX_AI_PROJECT": (
+            "",
+            "Google Cloud Project ID for Vertex AI",
+        ),
+        "VERTEX_AI_LOCATION": (
+            "us-central1",
+            "Google Cloud Location for Vertex AI",
+        ),
         "SAMBANOVA_API_KEY": (
             "",
             "SambaNova API Key",
```

llama_stack/providers/registry/inference.py

Lines changed: 30 additions & 0 deletions

```diff
@@ -213,6 +213,36 @@ def available_providers() -> list[ProviderSpec]:
             description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
         ),
     ),
+    remote_provider_spec(
+        api=Api.inference,
+        adapter=AdapterSpec(
+            adapter_type="vertexai",
+            pip_packages=["litellm", "google-cloud-aiplatform"],
+            module="llama_stack.providers.remote.inference.vertexai",
+            config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
+            description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages:
+
+• Enterprise-grade security: Uses Google Cloud's security controls and IAM
+• Better integration: Seamless integration with other Google Cloud services
+• Advanced features: Access to additional Vertex AI features like model tuning and monitoring
+• Authentication: Uses Google Cloud Application Default Credentials (ADC) instead of API keys
+
+Configuration:
+- Set VERTEX_AI_PROJECT environment variable (required)
+- Set VERTEX_AI_LOCATION environment variable (optional, defaults to us-central1)
+- Use Google Cloud Application Default Credentials or service account key
+
+Authentication Setup:
+Option 1 (Recommended): gcloud auth application-default login
+Option 2: Set GOOGLE_APPLICATION_CREDENTIALS to service account key path
+
+Available Models:
+- vertex_ai/gemini-2.0-flash
+- vertex_ai/gemini-2.5-flash
+- vertex_ai/gemini-2.5-pro""",
+        ),
+    ),
     remote_provider_spec(
         api=Api.inference,
         adapter=AdapterSpec(
```
Lines changed: 15 additions & 0 deletions (new file)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import VertexAIConfig


async def get_adapter_impl(config: VertexAIConfig, _deps):
    from .vertexai import VertexAIInferenceAdapter

    impl = VertexAIInferenceAdapter(config)
    await impl.initialize()
    return impl
```
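`get_adapter_impl` follows the usual provider-factory shape: build the adapter from its config, await `initialize()`, and return it. The contract can be exercised with a stand-in adapter (the `StubAdapter` class below is invented purely for illustration; the real `VertexAIInferenceAdapter` lives in the module imported above):

```python
import asyncio
from dataclasses import dataclass

@dataclass
class StubAdapter:
    # Stand-in for VertexAIInferenceAdapter; records whether initialize() ran.
    project: str
    ready: bool = False

    async def initialize(self) -> None:
        self.ready = True

async def get_adapter_impl(config: dict) -> StubAdapter:
    # Mirrors the factory above: construct, initialize, return.
    impl = StubAdapter(project=config["project"])
    await impl.initialize()
    return impl

adapter = asyncio.run(get_adapter_impl({"project": "my-demo-project"}))
print(adapter.ready)  # True
```

The deferred inner import in the real factory keeps heavyweight dependencies (litellm, google-cloud-aiplatform) out of the import path until the provider is actually instantiated.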
Lines changed: 45 additions & 0 deletions (new file)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


class VertexAIProviderDataValidator(BaseModel):
    vertex_project: str | None = Field(
        default=None,
        description="Google Cloud project ID for Vertex AI",
    )
    vertex_location: str | None = Field(
        default=None,
        description="Google Cloud location for Vertex AI (e.g., us-central1)",
    )


@json_schema_type
class VertexAIConfig(BaseModel):
    project: str = Field(
        description="Google Cloud project ID for Vertex AI",
    )
    location: str = Field(
        default="us-central1",
        description="Google Cloud location for Vertex AI",
    )

    @classmethod
    def sample_run_config(
        cls,
        project: str = "${env.VERTEX_AI_PROJECT:=}",
        location: str = "${env.VERTEX_AI_LOCATION:=us-central1}",
        **kwargs,
    ) -> dict[str, Any]:
        return {
            "project": project,
            "location": location,
        }
```
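In use, `VertexAIConfig` fills in the `location` default, and `sample_run_config` emits the env-var templates that appear in the run.yaml files in this commit. A trimmed sketch of that behavior (plain pydantic, with the llama_stack `json_schema_type` decorator omitted; the project ID is a placeholder):

```python
from typing import Any

from pydantic import BaseModel, Field

class VertexAIConfig(BaseModel):
    # Trimmed copy of the config class above, without @json_schema_type.
    project: str = Field(description="Google Cloud project ID for Vertex AI")
    location: str = Field(default="us-central1", description="Google Cloud location for Vertex AI")

    @classmethod
    def sample_run_config(cls, **kwargs) -> dict[str, Any]:
        # Returns env-var templates, not resolved values; resolution
        # happens when the run config is loaded.
        return {
            "project": "${env.VERTEX_AI_PROJECT:=}",
            "location": "${env.VERTEX_AI_LOCATION:=us-central1}",
        }

cfg = VertexAIConfig(project="my-demo-project")  # placeholder project ID
print(cfg.location)  # us-central1
print(VertexAIConfig.sample_run_config()["location"])  # us-central1 template
```

Note the split between the two models: `VertexAIConfig` is the static provider config, while `VertexAIProviderDataValidator` validates per-request overrides (`vertex_project`, `vertex_location`) supplied as provider data.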
