
Commit ee28c72

Authored and committed by Googler
feat(components): migrate function_based resolve_num_microbatches to rlhf_preprocessor component
PiperOrigin-RevId: 628226399
1 parent 788531b commit ee28c72


5 files changed: +13 -9 lines


components/google-cloud/google_cloud_pipeline_components/_implementation/llm/generated/refined_image_versions.py

Lines changed: 1 addition & 1 deletion
@@ -17,4 +17,4 @@
 DO NOT EDIT - This file is generated, manual changes will be overridden.
 """

-IMAGE_TAG = '20240423_1336'
+IMAGE_TAG = '20240425_1027_RC00'

components/google-cloud/google_cloud_pipeline_components/_implementation/llm/reinforcement_learning_graph.py

Lines changed: 2 additions & 4 deletions
@@ -62,6 +62,7 @@ def pipeline(
     location: str = _placeholders.LOCATION_PLACEHOLDER,
     tensorboard_resource_id: str = '',
     encryption_spec_key_name: str = '',
+    num_microbatches: int = 0,
 ) -> PipelineOutput:
   # fmt: off
   """Trains a reward model.
@@ -122,9 +123,6 @@ def pipeline(
       .set_display_name('Import Prompt Dataset')
       .set_caching_options(False)
   )
-  num_microbatches = function_based.resolve_num_microbatches(
-      large_model_reference=policy_model_reference,
-  ).set_display_name('Resolve Number of Microbatches')
   rl_model = (
       reinforcer.reinforcer(
           project=project,
@@ -150,7 +148,7 @@
           kl_coeff=kl_coeff,
           lora_dim=lora_dim,
           reward_lora_dim=reward_lora_dim,
-          num_microbatches=num_microbatches.output,
+          num_microbatches=num_microbatches,
           encryption_spec_key_name=encryption_spec_key_name,
           tensorboard_resource_id=tensorboard_resource_id,
       )
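
Both tuning graphs now take num_microbatches as a plain int pipeline parameter and hand it straight to their trainer component, instead of creating a function_based.resolve_num_microbatches task inside the graph and reading its .output. A minimal sketch of the new wiring, with hypothetical stub components and assuming the KFP v2 SDK (kfp) is installed; it illustrates the parameter flow only, not the actual GCPC pipeline:

from kfp import dsl


@dsl.component
def trainer_stub(num_microbatches: int):
  # Stand-in for reinforcer.reinforcer / reward_model_trainer in the real graphs.
  print(f'num_microbatches={num_microbatches}')


@dsl.pipeline(name='tuning-graph-sketch')
def tuning_graph_sketch(num_microbatches: int = 0):
  # The value arrives as a pipeline parameter, so the graph no longer needs a
  # resolver task or the .output indirection.
  trainer_stub(num_microbatches=num_microbatches)

The reward_model_graph.py change below follows the same pattern for reward_model_trainer.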

components/google-cloud/google_cloud_pipeline_components/_implementation/llm/reward_model_graph.py

Lines changed: 3 additions & 4 deletions
@@ -57,6 +57,7 @@ def pipeline(
     location: str = _placeholders.LOCATION_PLACEHOLDER,
     tensorboard_resource_id: str = '',
     encryption_spec_key_name: str = '',
+    num_microbatches: int = 0,
 ) -> PipelineOutput:
   # fmt: off
   """Trains a reward model.
@@ -82,6 +83,7 @@ def pipeline(
     location: Location used to run non-tuning components, i.e. components that do not require accelerators. If not specified the location used to run the pipeline will be used.
     tensorboard_resource_id: Optional tensorboard resource id in format `projects/{project_number}/locations/{location}/tensorboards/{tensorboard_id}`. If provided, tensorboard metrics will be uploaded to this location.
     encryption_spec_key_name: Customer-managed encryption key. If this is set, then all resources created by the CustomJob will be encrypted with the provided encryption key. Note that this is not supported for TPU at the moment.
+    num_microbatches: The number of microbatches to break the total batch size into during training.

   Returns:
     reward_model_adapter_path: Path to the output LoRA adapter.
@@ -140,9 +142,6 @@ def pipeline(
       .set_caching_options(False)
   )

-  num_microbatches = function_based.resolve_num_microbatches(
-      large_model_reference=reward_model_reference,
-  ).set_display_name('Resolve Number of Microbatches')
   reward_model = (
       reward_model_trainer.reward_model_trainer(
           project=project,
@@ -165,7 +164,7 @@ def pipeline(
           batch_size=batch_size,
           learning_rate_multiplier=reward_model_learning_rate_multiplier,
           lora_dim=lora_dim,
-          num_microbatches=num_microbatches.output,
+          num_microbatches=num_microbatches,
           encryption_spec_key_name=encryption_spec_key_name,
           tensorboard_resource_id=tensorboard_resource_id,
       )

components/google-cloud/google_cloud_pipeline_components/_implementation/llm/rlhf_preprocessor.py

Lines changed: 4 additions & 0 deletions
@@ -42,6 +42,7 @@ def rlhf_preprocessor(
     metadata_accelerator_type: dsl.OutputPath(str),  # pytype: disable=invalid-annotation
     metadata_accelerator_count: dsl.OutputPath(int),  # pytype: disable=invalid-annotation
     metadata_refined_image_uri: dsl.OutputPath(str),  # pytype: disable=invalid-annotation
+    metadata_num_microbatches: dsl.OutputPath(int),  # pytype: disable=invalid-annotation
     use_experimental_image: bool = False,
     evaluation_dataset: str = '',
     tensorboard_resource_id: str = '',
@@ -77,6 +78,8 @@ def rlhf_preprocessor(
     metadata_accelerator_type: Specific accelerator type for the custom job.
     metadata_accelerator_count: The number of accelerator.
     metadata_refined_image_uri: Docker image URI to use for the custom job.
+    metadata_num_microbatches: Number of microbatches to break the total batch
+      size into during training.
   """
   # fmt: on
   return gcpc_utils.build_serverless_customjob_container_spec(
@@ -110,6 +113,7 @@ def rlhf_preprocessor(
           f'--metadata_accelerator_type_path={metadata_accelerator_type}',
           f'--metadata_accelerator_count_path={metadata_accelerator_count}',
           f'--metadata_refined_image_uri_path={metadata_refined_image_uri}',
+          f'--metadata_num_microbatches_path={metadata_num_microbatches}',
         ],
     ),
     gcp_resources=gcp_resources,
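
Inside the preprocessor, metadata_num_microbatches is declared as dsl.OutputPath(int): at runtime KFP allocates a file path, that path is forwarded to the container through the --metadata_num_microbatches_path flag, the preprocessor writes the resolved value to the file, and the value then appears as a named output on the task. A minimal sketch of the OutputPath pattern with a hypothetical component, assuming the KFP v2 SDK is installed; the echoed 0 is a placeholder, not the real resolution logic:

from kfp import dsl


@dsl.container_component
def preprocessor_sketch(metadata_num_microbatches: dsl.OutputPath(int)):
  # The parameter resolves to a file path; whatever the container writes
  # there becomes the task output named 'metadata_num_microbatches'.
  return dsl.ContainerSpec(
      image='python:3.10-slim',
      command=['sh', '-c'],
      args=[
          f'mkdir -p "$(dirname {metadata_num_microbatches})" && '
          f'echo 0 > {metadata_num_microbatches}'
      ],
  )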

components/google-cloud/google_cloud_pipeline_components/preview/llm/rlhf/component.py

Lines changed: 3 additions & 0 deletions
@@ -107,6 +107,7 @@ def rlhf_pipeline(
       evaluation_dataset=eval_dataset,
       tensorboard_resource_id=tensorboard_resource_id,
   ).set_display_name('Preprocess Inputs')
+  num_microbatches = preprocess_metadata.outputs['metadata_num_microbatches']

   reward_model_pipeline = (
       (
@@ -145,6 +146,7 @@ def rlhf_pipeline(
           location=location,
           tensorboard_resource_id=tensorboard_resource_id,
           encryption_spec_key_name=encryption_spec_key_name,
+          num_microbatches=num_microbatches,
       )
   )
   .set_display_name('Train Reward Model')
@@ -189,6 +191,7 @@ def rlhf_pipeline(
           location=location,
           tensorboard_resource_id=tensorboard_resource_id,
           encryption_spec_key_name=encryption_spec_key_name,
+          num_microbatches=num_microbatches,
       ).set_display_name('Reinforcement Learning')

   has_inference_dataset = preprocess_metadata.outputs['has_inference_dataset']
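
At the top level, the value is read once from the preprocessor task's outputs and fanned out to both sub-pipelines. A minimal sketch of that fan-out with hypothetical stub components, assuming the KFP v2 SDK is installed; the real pipeline uses preprocess_metadata.outputs['metadata_num_microbatches'] because the preprocessor has several named outputs, while the single-output stub below uses .output:

from kfp import dsl


@dsl.component
def preprocess_stub() -> int:
  # Stand-in for rlhf_preprocessor's metadata_num_microbatches output.
  return 0


@dsl.component
def train_stub(num_microbatches: int):
  print(f'num_microbatches={num_microbatches}')


@dsl.pipeline(name='rlhf-fanout-sketch')
def rlhf_fanout_sketch():
  preprocess_metadata = preprocess_stub()
  num_microbatches = preprocess_metadata.output
  # Both tuning stages receive the same resolved value as a parameter.
  train_stub(num_microbatches=num_microbatches).set_display_name(
      'Train Reward Model'
  )
  train_stub(num_microbatches=num_microbatches).set_display_name(
      'Reinforcement Learning'
  )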
