Skip to content

Commit 1430035

Browse files
committed
+ add basic reward shaping func
1 parent b52809f commit 1430035

File tree

5 files changed

+106
-23
lines changed

5 files changed

+106
-23
lines changed

trinity/cli/launcher.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import ray
1010

1111
from trinity.common.config import Config, DataPipelineConfig, load_config
12-
from trinity.common.constants import EXPLORER_NAME, TRAINER_NAME
12+
from trinity.common.constants import EXPLORER_NAME, TRAINER_NAME, DataProcessorPipelineType
1313
from trinity.explorer.explorer import Explorer
1414
from trinity.trainer.trainer import Trainer
1515
from trinity.utils.log import get_logger
@@ -126,14 +126,14 @@ def activate_data_module(data_processor_url: str, config_path: str):
126126
return
127127

128128

129-
def validate_data_pipeline(data_pipeline_config: DataPipelineConfig, pipeline_type: str):
129+
def validate_data_pipeline(data_pipeline_config: DataPipelineConfig, pipeline_type: DataProcessorPipelineType):
130130
"""
131131
Check if the data pipeline is valid. The config should:
132132
1. Non-empty input buffer
133133
2. Different input/output buffers
134134
135135
:param data_pipeline_config: the input data pipeline to be validated.
136-
:param pipeline_type: the type of pipeline, should be one of ["task", "experience"]
136+
:param pipeline_type: the type of pipeline, should be one of DataProcessorPipelineType
137137
"""
138138
input_buffers = data_pipeline_config.input_buffers
139139
output_buffer = data_pipeline_config.output_buffer
@@ -147,7 +147,7 @@ def validate_data_pipeline(data_pipeline_config: DataPipelineConfig, pipeline_ty
147147
if output_buffer.name in input_buffer_names:
148148
logger.warning("Output buffer exists in input buffers. Won't activate it.")
149149
return False
150-
if pipeline_type == "task":
150+
if pipeline_type == DataProcessorPipelineType.TASK:
151151
# task pipeline specific
152152
# "raw" field should be True for task pipeline because the data source must be raw data files
153153
for buffer in input_buffers:
@@ -156,12 +156,13 @@ def validate_data_pipeline(data_pipeline_config: DataPipelineConfig, pipeline_ty
156156
'Input buffers should be raw data files for task pipeline ("raw" field should be True). Won\'t activate it.'
157157
)
158158
return False
159-
elif pipeline_type == "experience":
159+
elif pipeline_type == DataProcessorPipelineType.EXPERIENCE:
160160
# experience pipeline specific
161-
raise NotImplementedError("experience_pipeline is not implemented yet.")
161+
# No special items need to be checked.
162+
pass
162163
else:
163164
logger.warning(
164-
f'Invalid pipeline type: {pipeline_type}. Should be one of ["task", "experience"].'
165+
f'Invalid pipeline type: {pipeline_type}..'
165166
)
166167
return False
167168
return True
@@ -177,19 +178,19 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
177178
if (
178179
data_processor_config.data_processor_url
179180
and data_processor_config.task_pipeline
180-
and validate_data_pipeline(data_processor_config.task_pipeline, "task")
181+
and validate_data_pipeline(data_processor_config.task_pipeline, DataProcessorPipelineType.TASK)
181182
):
182183
activate_data_module(
183-
f"{data_processor_config.data_processor_url}/task_pipeline", config_path
184+
f"{data_processor_config.data_processor_url}/{DataProcessorPipelineType.TASK.value}", config_path
184185
)
185186
# try to activate experience pipeline for experiences
186187
if (
187188
data_processor_config.data_processor_url
188189
and data_processor_config.experience_pipeline
189-
and validate_data_pipeline(data_processor_config.experience_pipeline, "experience")
190+
and validate_data_pipeline(data_processor_config.experience_pipeline, DataProcessorPipelineType.EXPERIENCE)
190191
):
191192
activate_data_module(
192-
f"{data_processor_config.data_processor_url}/experience_pipeline", config_path
193+
f"{data_processor_config.data_processor_url}/{DataProcessorPipelineType.EXPERIENCE.value}", config_path
193194
)
194195
ray_namespace = config.ray_namespace
195196
if dlc:

trinity/common/config.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
StorageType,
1313
SyncMethod,
1414
TaskType,
15+
OpType
1516
)
1617
from trinity.utils.log import get_logger
1718

@@ -100,6 +101,13 @@ class StorageConfig:
100101
# ! DO NOT SET, automatically set corresponding to train/eval
101102
task_type: TaskType = TaskType.EXPLORE
102103

104+
@dataclass
class RewardShapingConfig:
    """Config for one reward shaping operation (experience pipeline only).

    The shaped reward is derived from an existing stats value, e.g. for
    ``OpType.ADD``: ``reward += weight * stats[stats_key]``.
    """

    # key in the sample's stats dict whose value drives the shaping
    stats_key: str = ""
    # arithmetic operator combining the current reward with the weighted stats value
    op_type: OpType = OpType.ADD
    # scaling factor applied to the stats value before combining
    weight: float = 1.0
103111

104112
@dataclass
105113
class DataPipelineConfig:
@@ -125,6 +133,9 @@ class DataPipelineConfig:
125133
priority_weights: Optional[Dict[str, float]] = None
126134
data_dist: Optional[str] = "gaussian" # one of ["gaussian", "uniform"]
127135

136+
# reward shaping related, only available for experience pipeline
137+
reward_shaping: Optional[List[RewardShapingConfig]] = field(default_factory=list)
138+
128139

129140
@dataclass
130141
class DataProcessorConfig:

trinity/common/constants.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,17 @@ class RunningStatus(Enum):
103103
RUNNING = "running"
104104
WAITING_SYNC = "waiting_sync"
105105
STOPPED = "stopped"
106+
107+
class DataProcessorPipelineType(Enum):
    """Data processor pipeline type.

    Each member's value doubles as the pipeline's URL path segment on the
    data processor service (e.g. ``{data_processor_url}/task_pipeline``).
    """

    # processes rollout experiences (reward shaping etc.)
    EXPERIENCE = "experience_pipeline"
    # prepares tasks from raw data files
    TASK = "task_pipeline"
112+
113+
class OpType(Enum):
    """Arithmetic operator type for reward shaping.

    Selects how a weighted stats value is combined with the existing
    reward (add / subtract / multiply / divide).
    """

    ADD = "add"
    SUB = "sub"
    MUL = "mul"
    DIV = "div"

trinity/data/controllers/active_iterator.py

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import os
22
import traceback
33
from numbers import Number
4-
from typing import Any, Dict, List
4+
from typing import Any, Dict, List, Optional
5+
from functools import partial
6+
from data_juicer.utils.constant import Fields
57

68
import ray
79

8-
from trinity.common.config import BufferConfig, DataPipelineConfig
10+
from trinity.common.config import BufferConfig, DataPipelineConfig, RewardShapingConfig
11+
from trinity.common.constants import DataProcessorPipelineType, OpType
912
from trinity.data.controllers.default_ops import DIMENSION_STATS_KEYS
1013
from trinity.data.controllers.task_parser import DataTaskParser
1114
from trinity.data.core.dataset import RftDataset
@@ -23,9 +26,24 @@ def __init__(
2326
self,
2427
config: DataPipelineConfig,
2528
buffer_config: BufferConfig,
29+
pipeline_type: Optional[DataProcessorPipelineType, str] = DataProcessorPipelineType.TASK,
2630
):
31+
"""
32+
The initialization method.
33+
34+
:param config: the data pipeline config.
35+
:param buffer_config: the buffer config.
36+
:param pipeline_type: the type of the activated pipeline.
37+
"""
2738
self.config = config
2839
self.buffer_config = buffer_config
40+
self.pipeline_type = pipeline_type
41+
if self.pipeline_type is None:
42+
self.pipeline_type = DataProcessorPipelineType.TASK
43+
if isinstance(self.pipeline_type, str):
44+
self.pipeline_type = DataProcessorPipelineType(pipeline_type)
45+
46+
# check if the llm agent is required
2947
if self.config.agent_model_name is not None and self.config.agent_model_config is not None:
3048
# get the api key
3149
api_key = os.environ.get("OPENAI_API_KEY")
@@ -42,6 +60,8 @@ def __init__(
4260
)
4361
else:
4462
self.llm_agent = None
63+
64+
# init task parser
4565
self.task_parser = DataTaskParser(config, self.llm_agent)
4666

4767
# Priority weights
@@ -153,34 +173,42 @@ def run(self):
153173
traceback.print_exc()
154174
return 6, "Grouping and computing priority score failed."
155175

156-
# step 7. track lineage if they are changed
176+
# step 7. reward shaping. Only available for experience pipeline and the reward shaping config is set
177+
try:
178+
if self.pipeline_type == DataProcessorPipelineType.EXPERIENCE and len(self.config.reward_shaping) > 0:
179+
reshaped_dataset = self._reward_shaping(scored_dataset)
180+
else:
181+
reshaped_dataset = scored_dataset
182+
except Exception:
183+
traceback.print_exc()
184+
return 7, "Reward shaping failed."
185+
186+
# step 8. track lineage if they are changed
157187
try:
158-
res_dataset = scored_dataset
188+
res_dataset = reshaped_dataset
159189
except Exception:
160190
traceback.print_exc()
161-
return 7, "Tracking lineage failed."
191+
return 8, "Tracking lineage failed."
162192

163-
# step 8
193+
# step 9, sort the dataset by the computed priority
164194
try:
165195
if "priority" in res_dataset.data.features:
166196
res_dataset.sort_by("priority", reverse=True)
167197
except Exception:
168198
traceback.print_exc()
169-
return 8, "Sorting results by priority failed."
199+
return 9, "Sorting results by priority failed."
170200

171-
# step 9. sort and export the result to the output buffer
201+
# step 10. sort and export the result to the output buffer
172202
try:
173203
res_dataset.write_to_buffer()
174204
except Exception:
175205
traceback.print_exc()
176-
return 9, "Exporting result to output buffer failed."
206+
return 10, "Exporting result to output buffer failed."
177207

178208
return 0, "success"
179209

180210
def _group_scores(self, dataset: RftDataset) -> RftDataset:
181211
# for perplexity, normalize them with the max value.
182-
from data_juicer.utils.constant import Fields
183-
184212
stats_min_max = {}
185213
for stats in dataset.data.features[Fields.stats]:
186214
all_stats = [
@@ -268,6 +296,35 @@ def _compute_priority_scores(self, dataset: RftDataset) -> RftDataset:
268296
dataset.data = dataset.data.map(self._compute_combined_score)
269297
return dataset
270298

299+
def _reward_shaping_single(self, sample, reward_shaping_config: RewardShapingConfig):
    """Apply one reward shaping operation to a single sample.

    Combines the sample's existing reward with a weighted stats value:
    ``reward <op>= weight * stats[stats_key]``.

    :param sample: a dataset sample carrying a stats dict (``Fields.stats``)
        and a reward field (``self.config.format.reward_key``).
    :param reward_shaping_config: the stats key, operator, and weight to use.
    :return: the sample, possibly with its reward modified in place.
    """
    tgt_stats = reward_shaping_config.stats_key
    op_type = reward_shaping_config.op_type
    # if the target stats does not exist, skip this stats and return the original sample
    if tgt_stats not in sample[Fields.stats]:
        return sample
    reward_key = self.config.format.reward_key
    operand = reward_shaping_config.weight * sample[Fields.stats][tgt_stats]
    if op_type == OpType.ADD:
        sample[reward_key] += operand
    elif op_type == OpType.MUL:
        sample[reward_key] *= operand
    elif op_type == OpType.SUB:
        sample[reward_key] -= operand
    elif op_type == OpType.DIV:
        # a zero weight*stats product would raise ZeroDivisionError inside
        # dataset.map and abort the whole reward shaping step; skip instead
        if operand == 0:
            return sample
        sample[reward_key] /= operand
    # unknown op types leave the reward untouched
    return sample
314+
315+
def _reward_shaping(self, rft_dataset: RftDataset) -> RftDataset:
    """Reshape the reward column of the dataset per the configured ops.

    :param rft_dataset: dataset whose samples carry stats and a reward field.
    :return: the same dataset object with rewards reshaped, or unchanged
        when it has no reward column.
    """
    # nothing to do when the dataset carries no reward column
    if self.config.format.reward_key not in rft_dataset.data.features:
        return rft_dataset
    shaped = rft_dataset.data
    # apply each configured shaping op in order across the whole dataset
    for shaping_cfg in self.config.reward_shaping:
        shaped = shaped.map(partial(self._reward_shaping_single, reward_shaping_config=shaping_cfg))
    rft_dataset.data = shaped
    return rft_dataset
327+
271328
@ray.method(num_returns=1)
272329
def select_batch(self, dataset: RftDataset, batch_size: int) -> List[Dict[str, Any]]:
273330
"""Select a batch of samples for training"""

trinity/data/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def data_processor(pipeline_type):
3333
}
3434
)
3535

36-
iterator = DataActiveIterator(pipeline_config, config.buffer)
36+
iterator = DataActiveIterator(pipeline_config, config.buffer, pipeline_type=pipeline_type)
3737
ret, msg = iterator.run()
3838
return jsonify({"return_code": ret, "message": msg})
3939

0 commit comments

Comments (0)