Skip to content

Commit d566c21

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: GenAI SDK client - Enabling zero-shot prompt optimization for prompts from Android API
PiperOrigin-RevId: 821852727
1 parent 57d2709 commit d566c21

File tree

5 files changed

+110
-29
lines changed

5 files changed

+110
-29
lines changed

tests/unit/vertexai/genai/test_prompt_optimizer.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,28 @@ def test_prompt_optimizer_optimize(self, mock_custom_job, mock_client):
6969
def test_prompt_optimizer_optimize_prompt(
7070
self, mock_custom_optimize_prompt, mock_client
7171
):
72-
"""Test that prompt_optimizer.optimize method creates a custom job."""
72+
"""Test that prompt_optimizer.optimize method calls _custom_optimize_prompt."""
7373
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
7474
test_client.prompt_optimizer.optimize_prompt(prompt="test_prompt")
7575
mock_client.assert_called_once()
7676
mock_custom_optimize_prompt.assert_called_once()
7777

78+
@mock.patch.object(prompt_optimizer.PromptOptimizer, "_custom_optimize_prompt")
79+
def test_prompt_optimizer_optimize_prompt_with_optimization_target(
80+
self, mock_custom_optimize_prompt
81+
):
82+
"""Test that prompt_optimizer.optimize_prompt method calls _custom_optimize_prompt with optimization_target."""
83+
test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
84+
config = types.OptimizeConfig(
85+
optimization_target=types.OptimizeTarget.GEMINI_NANO,
86+
)
87+
test_client.prompt_optimizer.optimize_prompt(
88+
prompt="test_prompt",
89+
config=config,
90+
)
91+
mock_custom_optimize_prompt.assert_called_once_with(
92+
content=mock.ANY,
93+
config=config,
94+
)
95+
7896
# TODO(b/415060797): add more tests for prompt_optimizer.optimize

vertexai/_genai/_prompt_optimizer_utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,40 @@
1515
"""Utility functions for prompt optimizer."""
1616

1717
import json
18+
from google.genai import types as genai_types
1819
from . import types
1920

2021

22+
def _get_parameter_model(
23+
config: types.OptimizeConfigOrDict | None,
24+
content: genai_types.ContentOrDict | None
25+
):
26+
"""Get the parameters from the config and content."""
27+
28+
if isinstance(config, types.OptimizeConfig):
29+
optimization_target = config.optimization_target
30+
elif isinstance(config, dict):
31+
optimization_target = config.get("optimization_target")
32+
else:
33+
optimization_target = None
34+
35+
if optimization_target is None:
36+
return types._OptimizeRequestParameters(
37+
content=content,
38+
config=config,
39+
)
40+
41+
if optimization_target != types.OptimizeTarget.GEMINI_NANO:
42+
raise ValueError(
43+
"Currently, only `gemini_nano` optimization_target is supported."
44+
)
45+
return types._OptimizeRequestParameters(
46+
content=content,
47+
optimization_target="OPTIMIZATION_TARGET_GEMINI_NANO",
48+
config=config,
49+
)
50+
51+
2152
def _get_service_account(
2253
config: types.PromptOptimizerVAPOConfigOrDict,
2354
) -> str:

vertexai/_genai/prompt_optimizer.py

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,13 @@ def _OptimizeRequestParameters_to_vertex(
178178
if getv(from_object, ["config"]) is not None:
179179
setv(to_object, ["config"], getv(from_object, ["config"]))
180180

181+
if getv(from_object, ["optimization_target"]) is not None:
182+
setv(
183+
to_object,
184+
["optimizationTarget"],
185+
getv(from_object, ["optimization_target"]),
186+
)
187+
181188
return to_object
182189

183190

@@ -189,6 +196,7 @@ def _optimize_prompt(
189196
*,
190197
content: Optional[genai_types.ContentOrDict] = None,
191198
config: Optional[types.OptimizeConfigOrDict] = None,
199+
optimization_target: Optional[str] = None,
192200
) -> types.OptimizeResponseEndpoint:
193201
"""
194202
Optimize a single prompt.
@@ -197,6 +205,7 @@ def _optimize_prompt(
197205
parameter_model = types._OptimizeRequestParameters(
198206
content=content,
199207
config=config,
208+
optimization_target=optimization_target,
200209
)
201210

202211
request_url_dict: Optional[dict[str, str]]
@@ -468,7 +477,10 @@ def optimize(
468477
return job
469478

470479
def optimize_prompt(
471-
self, *, prompt: str, config: Optional[types.OptimizeConfig] = None
480+
self,
481+
*,
482+
prompt: str,
483+
config: types.OptimizeConfig | None = None,
472484
) -> types.OptimizeResponse:
473485
"""Makes an API request to _optimize_prompt and returns the parsed response.
474486
@@ -481,24 +493,26 @@ def optimize_prompt(
481493
Args:
482494
prompt: The prompt to optimize.
483495
config: The configuration for prompt optimization. Currently, config is
484-
not supported for a single prompt optimization.
496+
either None or
497+
types.OptimizeConfig(
498+
optimization_target=types.OptimizeTarget.GEMINI_NANO
499+
)
485500
Returns:
486501
The parsed response from the API request.
487502
"""
488-
if config is not None:
489-
raise ValueError(
490-
"Currently, config is not supported for a single prompt optimization."
491-
)
492503

493504
prompt = genai_types.Content(parts=[genai_types.Part(text=prompt)], role="user")
494505
# TODO: b/435653980 - replace the custom method with a generated method.
495-
return self._custom_optimize_prompt(content=prompt)
506+
return self._custom_optimize_prompt(
507+
content=prompt,
508+
config=config,
509+
)
496510

497511
def _custom_optimize_prompt(
498512
self,
499513
*,
500-
content: Optional[genai_types.ContentOrDict] = None,
501-
config: Optional[types.OptimizeConfigOrDict] = None,
514+
content: genai_types.ContentOrDict | None,
515+
config: types.OptimizeConfigOrDict | None,
502516
) -> types.OptimizeResponse:
503517
"""Optimize a single prompt.
504518
@@ -507,11 +521,7 @@ def _custom_optimize_prompt(
507521
the parsed response.
508522
"""
509523

510-
parameter_model = types._OptimizeRequestParameters(
511-
content=content,
512-
config=config,
513-
)
514-
524+
parameter_model = _prompt_optimizer_utils._get_parameter_model(config, content)
515525
request_url_dict: Optional[dict[str, str]]
516526
if not self._api_client.vertexai:
517527
raise ValueError("This method is only supported in the Vertex AI client.")
@@ -576,6 +586,7 @@ async def _optimize_prompt(
576586
*,
577587
content: Optional[genai_types.ContentOrDict] = None,
578588
config: Optional[types.OptimizeConfigOrDict] = None,
589+
optimization_target: Optional[str] = None,
579590
) -> types.OptimizeResponseEndpoint:
580591
"""
581592
Optimize a single prompt.
@@ -584,6 +595,7 @@ async def _optimize_prompt(
584595
parameter_model = types._OptimizeRequestParameters(
585596
content=content,
586597
config=config,
598+
optimization_target=optimization_target,
587599
)
588600

589601
request_url_dict: Optional[dict[str, str]]
@@ -841,16 +853,12 @@ async def optimize(
841853
async def _custom_optimize_prompt(
842854
self,
843855
*,
844-
content: Optional[genai_types.ContentOrDict] = None,
845-
config: Optional[types.OptimizeConfigOrDict] = None,
856+
content: genai_types.ContentOrDict | None,
857+
config: types.OptimizeConfigOrDict | None,
846858
) -> types.OptimizeResponse:
847859
"""Optimize a single prompt."""
848860

849-
parameter_model = types._OptimizeRequestParameters(
850-
content=content,
851-
config=config,
852-
)
853-
861+
parameter_model = _prompt_optimizer_utils._get_parameter_model(config, content)
854862
request_url_dict: Optional[dict[str, str]]
855863
if not self._api_client.vertexai:
856864
raise ValueError("This method is only supported in the Vertex AI client.")
@@ -909,7 +917,10 @@ async def _custom_optimize_prompt(
909917
return final_response
910918

911919
async def optimize_prompt(
912-
self, *, prompt: str, config: Optional[types.OptimizeConfig] = None
920+
self,
921+
*,
922+
prompt: str,
923+
config: types.OptimizeConfig | None = None,
913924
) -> types.OptimizeResponse:
914925
"""Makes an async request to _optimize_prompt and returns an optimized prompt.
915926
@@ -921,15 +932,17 @@ async def optimize_prompt(
921932
Args:
922933
prompt: The prompt to optimize.
923934
config: The configuration for prompt optimization. Currently, config is
924-
not supported for a single prompt optimization.
935+
either None or
936+
types.OptimizeConfig(
937+
optimization_target=types.OptimizeTarget.GEMINI_NANO
938+
)
925939
Returns:
926940
The parsed response from the API request.
927941
"""
928-
if config is not None:
929-
raise ValueError(
930-
"Currently, config is not supported for a single prompt optimization."
931-
)
932942

933943
prompt = genai_types.Content(parts=[genai_types.Part(text=prompt)], role="user")
934944
# TODO: b/435653980 - replace the custom method with a generated method.
935-
return await self._custom_optimize_prompt(content=prompt)
945+
return await self._custom_optimize_prompt(
946+
content=prompt,
947+
config=config,
948+
)

vertexai/_genai/types/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,7 @@
571571
from .common import OptimizeResponseEndpointDict
572572
from .common import OptimizeResponseEndpointOrDict
573573
from .common import OptimizeResponseOrDict
574+
from .common import OptimizeTarget
574575
from .common import PairwiseChoice
575576
from .common import PairwiseMetricInput
576577
from .common import PairwiseMetricInputDict
@@ -1786,6 +1787,7 @@
17861787
"RubricContentType",
17871788
"EvaluationRunState",
17881789
"Importance",
1790+
"OptimizeTarget",
17891791
"GenerateMemoriesResponseGeneratedMemoryAction",
17901792
"PromptData",
17911793
"PromptDataDict",

vertexai/_genai/types/common.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,13 @@ class Importance(_common.CaseInSensitiveEnum):
343343
"""Low importance."""
344344

345345

346+
class OptimizeTarget(_common.CaseInSensitiveEnum):
347+
"""None"""
348+
349+
GEMINI_NANO = "GEMINI_NANO"
350+
"""The data driven prompt optimizer designer for prompts from Android core API."""
351+
352+
346353
class GenerateMemoriesResponseGeneratedMemoryAction(_common.CaseInSensitiveEnum):
347354
"""The action to take."""
348355

@@ -3998,6 +4005,9 @@ class OptimizeConfig(_common.BaseModel):
39984005
http_options: Optional[genai_types.HttpOptions] = Field(
39994006
default=None, description="""Used to override HTTP request options."""
40004007
)
4008+
optimization_target: Optional[OptimizeTarget] = Field(
4009+
default=None, description=""""""
4010+
)
40014011

40024012

40034013
class OptimizeConfigDict(TypedDict, total=False):
@@ -4006,6 +4016,9 @@ class OptimizeConfigDict(TypedDict, total=False):
40064016
http_options: Optional[genai_types.HttpOptionsDict]
40074017
"""Used to override HTTP request options."""
40084018

4019+
optimization_target: Optional[OptimizeTarget]
4020+
""""""
4021+
40094022

40104023
OptimizeConfigOrDict = Union[OptimizeConfig, OptimizeConfigDict]
40114024

@@ -4015,6 +4028,7 @@ class _OptimizeRequestParameters(_common.BaseModel):
40154028

40164029
content: Optional[genai_types.Content] = Field(default=None, description="""""")
40174030
config: Optional[OptimizeConfig] = Field(default=None, description="""""")
4031+
optimization_target: Optional[str] = Field(default=None, description="""""")
40184032

40194033

40204034
class _OptimizeRequestParametersDict(TypedDict, total=False):
@@ -4026,6 +4040,9 @@ class _OptimizeRequestParametersDict(TypedDict, total=False):
40264040
config: Optional[OptimizeConfigDict]
40274041
""""""
40284042

4043+
optimization_target: Optional[str]
4044+
""""""
4045+
40294046

40304047
_OptimizeRequestParametersOrDict = Union[
40314048
_OptimizeRequestParameters, _OptimizeRequestParametersDict

0 commit comments

Comments
 (0)