Skip to content

Commit f175c71

Browse files
author
Googler
committed
feat(components): migrate function_based resolve_regional_endpoint to rlhf_preprocessor component
PiperOrigin-RevId: 629315370
1 parent 401aac7 commit f175c71

File tree

4 files changed

+13
-7
lines changed

4 files changed

+13
-7
lines changed

components/google-cloud/google_cloud_pipeline_components/_implementation/llm/deployment_graph.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def pipeline(
3939
deploy_model: bool = True,
4040
encryption_spec_key_name: str = '',
4141
upload_location: str = _placeholders.LOCATION_PLACEHOLDER,
42+
regional_endpoint: str = '',
4243
) -> PipelineOutput:
4344
# fmt: off
4445
"""Uploads a tuned language model and (optionally) deploys it to an endpoint.
@@ -51,16 +52,13 @@ def pipeline(
5152
deploy_model: Whether to deploy the model to an endpoint in `us-central1`. Default is True.
5253
encryption_spec_key_name: Customer-managed encryption key. If this is set, then all resources created by the CustomJob will be encrypted with the provided encryption key. Note that this is not supported for TPU at the moment.
5354
upload_location: Region to upload and deploy the model to. Default is the location used to run the pipeline components.
55+
regional_endpoint: Regional endpoint to upload the model to.
5456
5557
Returns:
5658
model_resource_name: Path to the model uploaded to the Model Registry. This will be an empty string if the model was not deployed.
5759
endpoint_resource_name: Path to the Online Prediction Endpoint. This will be an empty string if the model was not deployed.
5860
"""
5961
# fmt: on
60-
regional_endpoint = function_based.resolve_regional_endpoint(
61-
upload_location=upload_location
62-
).set_display_name('Resolve Regional Endpoint')
63-
6462
display_name = (
6563
function_based.resolve_model_display_name(
6664
large_model_reference=large_model_reference,
@@ -76,7 +74,7 @@ def pipeline(
7674
upload_task = upload_llm_model.refined_upload_llm_model(
7775
project=_placeholders.PROJECT_ID_PLACEHOLDER,
7876
location=upload_location,
79-
regional_endpoint=regional_endpoint.output,
77+
regional_endpoint=regional_endpoint,
8078
artifact_uri=output_adapter_path,
8179
model_display_name=display_name.output,
8280
model_reference_name=large_model_reference,
@@ -93,7 +91,7 @@ def pipeline(
9391
location=upload_location,
9492
model_resource_name=upload_task.outputs['model_resource_name'],
9593
display_name=display_name.output,
96-
regional_endpoint=regional_endpoint.output,
94+
regional_endpoint=regional_endpoint,
9795
deploy_model=deploy_model.output,
9896
encryption_spec_key_name=encryption_spec_key_name,
9997
).set_display_name('Deploy Model')

components/google-cloud/google_cloud_pipeline_components/_implementation/llm/generated/refined_image_versions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@
1717
DO NOT EDIT - This file is generated, manual changes will be overridden.
1818
"""
1919

20-
IMAGE_TAG = '20240428_1707'
20+
IMAGE_TAG = '20240429_1553_RC00'

components/google-cloud/google_cloud_pipeline_components/_implementation/llm/rlhf_preprocessor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,13 @@ def rlhf_preprocessor(
4545
metadata_accelerator_count: dsl.OutputPath(int), # pytype: disable=invalid-annotation
4646
metadata_refined_image_uri: dsl.OutputPath(str), # pytype: disable=invalid-annotation
4747
metadata_num_microbatches: dsl.OutputPath(int), # pytype: disable=invalid-annotation
48+
metadata_upload_location: dsl.OutputPath(str), # pytype: disable=invalid-annotation
4849
use_experimental_image: bool = False,
4950
evaluation_dataset: str = '',
5051
tensorboard_resource_id: str = '',
5152
input_reference_model_path: str = '',
5253
image_uri: str = utils.get_default_image_uri('refined_cpu', ''),
54+
upload_location: str = '',
5355
) -> dsl.ContainerSpec: # pylint: disable=g-doc-args
5456
# fmt: off
5557
"""Preprocess RLHF pipeline inputs.
@@ -70,6 +72,7 @@ def rlhf_preprocessor(
7072
metadata_reward_model_reference: The base model for training reward model. The name should be in capitalized snake case format.
7173
metadata_reward_model_path: The model checkpoint path for the reward model.
7274
image_uri: Docker image URI to use for the custom job.
75+
upload_location: Region where the model will be uploaded.
7376
7477
Returns:
7578
gcp_resources: GCP resources that can be used to track the custom job.
@@ -82,6 +85,7 @@ def rlhf_preprocessor(
8285
metadata_refined_image_uri: Docker image URI to use for the custom job.
8386
metadata_num_microbatches: Number of microbatches to break the total batch
8487
size into during training.
88+
metadata_upload_location: Regional endpoint resolved for `upload_location`, used when uploading and deploying the model.
8589
"""
8690
# fmt: on
8791
return gcpc_utils.build_serverless_customjob_container_spec(
@@ -104,6 +108,7 @@ def rlhf_preprocessor(
104108
f'--artifact_registry={artifact_registry}',
105109
f'--tag={tag}',
106110
f'--use_experimental_image={use_experimental_image}',
111+
f'--upload_location={upload_location}',
107112
f'--has_tensorboard_id_path={has_tensorboard_id}',
108113
f'--has_inference_dataset_path={has_inference_dataset}',
109114
f'--metadata_candidate_columns_string_path={metadata_candidate_columns_string}',
@@ -117,6 +122,7 @@ def rlhf_preprocessor(
117122
f'--metadata_accelerator_count_path={metadata_accelerator_count}',
118123
f'--metadata_refined_image_uri_path={metadata_refined_image_uri}',
119124
f'--metadata_num_microbatches_path={metadata_num_microbatches}',
125+
f'--metadata_upload_location_path={metadata_upload_location}',
120126
],
121127
),
122128
gcp_resources=gcp_resources,

components/google-cloud/google_cloud_pipeline_components/preview/llm/rlhf/component.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def rlhf_pipeline(
106106
tag=env.get_private_image_tag(),
107107
evaluation_dataset=eval_dataset,
108108
tensorboard_resource_id=tensorboard_resource_id,
109+
upload_location=location,
109110
).set_display_name('Preprocess Inputs')
110111
num_microbatches = preprocess_metadata.outputs['metadata_num_microbatches']
111112

@@ -233,6 +234,7 @@ def rlhf_pipeline(
233234
deploy_model=deploy_model,
234235
encryption_spec_key_name=encryption_spec_key_name,
235236
upload_location=location,
237+
regional_endpoint=preprocess_metadata.outputs['metadata_upload_location'],
236238
).set_display_name('Upload and Deploy Tuned Model')
237239

238240
return PipelineOutput(

0 commit comments

Comments
 (0)