Skip to content

Commit 87a579b

Browse files
committed
docs(discoveryengine): Add Search Tuning Samples
1 parent 3f85885 commit 87a579b

File tree

3 files changed

+181
-0
lines changed

3 files changed

+181
-0
lines changed

discoveryengine/search_sample.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,56 @@ def search_sample(
9393

9494

9595
# [END genappbuilder_search]
96+
97+
# [START genappbuilder_search_tuned_model]
98+
99+
from google.api_core.client_options import ClientOptions
100+
from google.cloud import discoveryengine_v1alpha as discoveryengine
101+
102+
# TODO(developer): Uncomment these variables before running the sample.
103+
# project_id = "YOUR_PROJECT_ID"
104+
# location = "YOUR_LOCATION" # Values: "global", "us", "eu"
105+
# engine_id = "YOUR_APP_ID"
106+
# search_query = "YOUR_SEARCH_QUERY"
107+
108+
109+
def search_tuned_model_sample(
110+
project_id: str,
111+
location: str,
112+
engine_id: str,
113+
search_query: str,
114+
) -> discoveryengine.services.search_service.pagers.SearchPager:
115+
# For more information, refer to:
116+
# https://cloud.google.com/generative-ai-app-builder/docs/locations#specify_a_multi-region_for_your_data_store
117+
client_options = (
118+
ClientOptions(api_endpoint=f"{location}-discoveryengine.googleapis.com")
119+
if location != "global"
120+
else None
121+
)
122+
123+
# Create a client
124+
client = discoveryengine.SearchServiceClient(client_options=client_options)
125+
126+
# The full resource name of the search app serving config
127+
serving_config = f"projects/{project_id}/locations/{location}/collections/default_collection/engines/{engine_id}/servingConfigs/default_config"
128+
129+
# Refer to the `SearchRequest` reference for all supported fields:
130+
# https://cloud.google.com/python/docs/reference/discoveryengine/latest/google.cloud.discoveryengine_v1.types.SearchRequest
131+
request = discoveryengine.SearchRequest(
132+
serving_config=serving_config,
133+
query=search_query,
134+
custom_fine_tuning_spec=discoveryengine.CustomFineTuningSpec(
135+
enable_search_adaptor=True
136+
),
137+
)
138+
139+
page_result = client.search(request)
140+
141+
# Handle the response
142+
for response in page_result:
143+
print(response)
144+
145+
return page_result
146+
147+
148+
# [END genappbuilder_search_tuned_model]
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
# [START genappbuilder_train_custom_model]
17+
18+
from google.api_core.client_options import ClientOptions
19+
from google.api_core.operation import Operation
20+
from google.cloud import discoveryengine
21+
22+
# TODO(developer): Uncomment these variables before running the sample.
23+
# project_id = "YOUR_PROJECT_ID"
24+
# location = "YOUR_LOCATION" # Values: "global"
25+
# data_store_id = "YOUR_DATA_STORE_ID"
26+
# corpus_data_path = "gs://my-bucket/corpus.jsonl"
27+
# query_data_path = "gs://my-bucket/query.jsonl"
28+
# train_data_path = "gs://my-bucket/train.tsv"
29+
# test_data_path = "gs://my-bucket/test.tsv"
30+
31+
32+
def train_custom_model_sample(
33+
project_id: str,
34+
location: str,
35+
data_store_id: str,
36+
corpus_data_path: str,
37+
query_data_path: str,
38+
train_data_path: str,
39+
test_data_path: str,
40+
) -> Operation:
41+
# For more information, refer to:
42+
# https://cloud.google.com/generative-ai-app-builder/docs/locations#specify_a_multi-region_for_your_data_store
43+
client_options = (
44+
ClientOptions(api_endpoint=f"{location}-discoveryengine.googleapis.com")
45+
if location != "global"
46+
else None
47+
)
48+
# Create a client
49+
client = discoveryengine.SearchTuningServiceClient(client_options=client_options)
50+
51+
# The full resource name of the data store
52+
data_store = f"projects/{project_id}/locations/{location}/collections/default_collection/dataStores/{data_store_id}"
53+
54+
# Make the request
55+
operation = client.train_custom_model(
56+
request=discoveryengine.TrainCustomModelRequest(
57+
gcs_training_input=discoveryengine.TrainCustomModelRequest.GcsTrainingInput(
58+
corpus_data_path=corpus_data_path,
59+
query_data_path=query_data_path,
60+
train_data_path=train_data_path,
61+
test_data_path=test_data_path,
62+
),
63+
data_store=data_store,
64+
model_type="search-tuning",
65+
)
66+
)
67+
68+
# Optional: Wait for training to complete
69+
# print(f"Waiting for operation to complete: {operation.operation.name}")
70+
# response = operation.result()
71+
72+
# After the operation is complete,
73+
# get information from operation metadata
74+
# metadata = discoveryengine.TrainCustomModelMetadata(operation.metadata)
75+
76+
# Handle the response
77+
# print(response)
78+
# print(metadata)
79+
print(operation)
80+
81+
return operation
82+
83+
84+
# [END genappbuilder_train_custom_model]
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
import os
17+
18+
from discoveryengine import train_custom_model_sample
19+
from google.api_core.exceptions import AlreadyExists
20+
21+
project_id = os.environ["GOOGLE_CLOUD_PROJECT"]
22+
location = "global"
23+
data_store_id = "tuning-data-store"
24+
corpus_data_path = "gs://cloud-samples-data/gen-app-builder/search-tuning/corpus.jsonl"
25+
query_data_path = "gs://cloud-samples-data/gen-app-builder/search-tuning/query.jsonl"
26+
train_data_path = "gs://cloud-samples-data/gen-app-builder/search-tuning/training.tsv"
27+
test_data_path = "gs://cloud-samples-data/gen-app-builder/search-tuning/test.tsv"
28+
29+
30+
def test_train_custom_model():
31+
try:
32+
operation = train_custom_model_sample.train_custom_model_sample(
33+
project_id,
34+
location,
35+
data_store_id,
36+
corpus_data_path,
37+
query_data_path,
38+
train_data_path,
39+
test_data_path,
40+
)
41+
assert operation
42+
except AlreadyExists:
43+
# Ignore AlreadyExists; training is already in progress.
44+
pass

0 commit comments

Comments
 (0)