Skip to content

Commit efefe34

Browse files
author
Googler
committed
feat(components): migrate function_based convert_to_delimited_string to rlhf_preprocessor component
PiperOrigin-RevId: 628282787
1 parent 0c26c04 commit efefe34

File tree

4 files changed

+12
-7
lines changed

4 files changed

+12
-7
lines changed

components/google-cloud/google_cloud_pipeline_components/_implementation/llm/generated/refined_image_versions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@
1717
DO NOT EDIT - This file is generated, manual changes will be overridden.
1818
"""
1919

20-
IMAGE_TAG = '20240425_1027_RC00'
20+
IMAGE_TAG = '20240425_1734_RC00'

components/google-cloud/google_cloud_pipeline_components/_implementation/llm/reward_model_graph.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from google_cloud_pipeline_components._implementation.llm import preprocess_chat_dataset
2222
from google_cloud_pipeline_components._implementation.llm import private_text_comparison_importer
2323
from google_cloud_pipeline_components._implementation.llm import reward_model_trainer
24+
from google_cloud_pipeline_components._implementation.llm import rlhf_preprocessor
2425
from google_cloud_pipeline_components._implementation.llm import upload_tensorboard_metrics
2526
import kfp
2627

@@ -45,6 +46,7 @@ def pipeline(
4546
accelerator_type: str,
4647
accelerator_count: int,
4748
reward_model_image_uri: str,
49+
comma_separated_candidates_field_names: str,
4850
prompt_sequence_length: int = 512,
4951
target_sequence_length: int = 64,
5052
batch_size: int = 64,
@@ -72,6 +74,7 @@ def pipeline(
7274
accelerator_type: Specific accelerator type for the custom job.
7375
accelerator_count: The number of accelerators.
7476
reward_model_image_uri: Docker image URI to use for the reward model training job.
77+
comma_separated_candidates_field_names: Comma-separated list of fields that contain candidate text, e.g. ``'field_1,field_2,field_3'``.
7578
prompt_sequence_length: Maximum tokenized sequence length for input text. Higher values increase memory overhead. This value should be at most 8192. Default value is 512.
7679
target_sequence_length: Maximum tokenized sequence length for target text. Higher values increase memory overhead. This value should be at most 1024. Default value is 64.
7780
batch_size: Number of examples in each finetuning step. Default is 64.
@@ -91,7 +94,6 @@ def pipeline(
9194
"""
9295
# fmt: on
9396
prompt_column = 'input_text'
94-
candidate_columns = ['candidate_0', 'candidate_1']
9597
choice_column = 'choice'
9698

9799
processed_preference_dataset = (
@@ -103,9 +105,6 @@ def pipeline(
103105
).set_display_name('Preprocess Prompt Dataset')
104106
)
105107

106-
comma_separated_candidates_field_names = (
107-
function_based.convert_to_delimited_string(items=candidate_columns)
108-
)
109108
preference_dataset_importer = (
110109
private_text_comparison_importer.private_text_comparison_importer(
111110
project=project,
@@ -114,7 +113,7 @@ def pipeline(
114113
'processed_dataset_uri'
115114
],
116115
inputs_field_name=prompt_column,
117-
comma_separated_candidates_field_names=comma_separated_candidates_field_names.output,
116+
comma_separated_candidates_field_names=comma_separated_candidates_field_names,
118117
choice_field_name=choice_column,
119118
split=env.TRAIN_SPLIT,
120119
large_model_reference=reward_model_reference,
@@ -131,7 +130,7 @@ def pipeline(
131130
location=location,
132131
input_text=eval_dataset,
133132
inputs_field_name=prompt_column,
134-
comma_separated_candidates_field_names=comma_separated_candidates_field_names.output,
133+
comma_separated_candidates_field_names=comma_separated_candidates_field_names,
135134
choice_field_name=choice_column,
136135
split=env.TRAIN_SPLIT,
137136
large_model_reference=reward_model_reference,

components/google-cloud/google_cloud_pipeline_components/_implementation/llm/rlhf_preprocessor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""Component that preprocesses inputs for Reinforcement Learning from Human Feedback (RLHF)."""
1515

1616
import os
17+
from typing import List
1718

1819
from google_cloud_pipeline_components import _placeholders
1920
from google_cloud_pipeline_components import utils as gcpc_utils
@@ -33,6 +34,7 @@ def rlhf_preprocessor(
3334
gcp_resources: dsl.OutputPath(str), # pytype: disable=invalid-annotation
3435
has_tensorboard_id: dsl.OutputPath(bool), # pytype: disable=invalid-annotation
3536
has_inference_dataset: dsl.OutputPath(bool), # pytype: disable=invalid-annotation
37+
metadata_candidate_columns_string: dsl.OutputPath(str), # pytype: disable=invalid-annotation
3638
metadata_large_model_reference: dsl.OutputPath(str), # pytype: disable=invalid-annotation
3739
metadata_reference_model_path: dsl.OutputPath(str), # pytype: disable=invalid-annotation
3840
metadata_reward_model_reference: dsl.OutputPath(str), # pytype: disable=invalid-annotation
@@ -104,6 +106,7 @@ def rlhf_preprocessor(
104106
f'--use_experimental_image={use_experimental_image}',
105107
f'--has_tensorboard_id_path={has_tensorboard_id}',
106108
f'--has_inference_dataset_path={has_inference_dataset}',
109+
f'--metadata_candidate_columns_string_path={metadata_candidate_columns_string}',
107110
f'--metadata_large_model_reference_path={metadata_large_model_reference}',
108111
f'--metadata_reference_model_path_path={metadata_reference_model_path}',
109112
f'--metadata_reward_model_reference_path={metadata_reward_model_reference}',

components/google-cloud/google_cloud_pipeline_components/preview/llm/rlhf/component.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,9 @@ def rlhf_pipeline(
133133
reward_model_image_uri=preprocess_metadata.outputs[
134134
'metadata_refined_image_uri'
135135
],
136+
comma_separated_candidates_field_names=preprocess_metadata.outputs[
137+
'metadata_candidate_columns_string'
138+
],
136139
prompt_sequence_length=prompt_sequence_length,
137140
target_sequence_length=target_sequence_length,
138141
eval_dataset=validate_pipeline_task.outputs[

0 commit comments

Comments
 (0)