Skip to content

Commit 72573d9

Browse files
Add preprocessor to setup_response_data for RL training
1 parent d729270 commit 72573d9

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

nemo_rl/data/utils.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,7 @@ def setup_response_data(
8484
print("\n▶ Setting up data...")
8585
# setup train dataset
8686
task_data_processors = {}
87+
task_data_preprocessors = {}
8788
task_to_env = {}
8889
data_list = []
8990

@@ -99,6 +100,8 @@ def setup_response_data(
99100
# bind task_name to task_data_processors and task_to_env
100101
task_name = data.task_name
101102
task_data_processors[task_name] = (data.task_spec, data.processor)
103+
if hasattr(data, "preprocessor") and data.preprocessor is not None:
104+
task_data_preprocessors[task_name] = data.preprocessor
102105
if has_envs:
103106
task_to_env[task_name] = envs[cfg["env_name"]]
104107

@@ -108,12 +111,14 @@ def setup_response_data(
108111
tokenizer,
109112
None,
110113
task_data_processors,
114+
task_data_preprocessors=task_data_preprocessors,
111115
max_seq_length=data_config["max_input_seq_length"],
112116
)
113117
print(f" ✓ Training dataset loaded with {len(dataset)} samples.")
114118

115119
# setup validation dataset
116120
val_task_data_processors = {}
121+
val_task_data_preprocessors = {}
117122
val_task_to_env = {}
118123
val_data_list = []
119124

@@ -124,6 +129,8 @@ def setup_response_data(
124129
# bind task_name to task_data_processors and task_to_env
125130
task_name = data.task_name
126131
val_task_data_processors[task_name] = task_data_processors[task_name]
132+
if task_name in task_data_preprocessors:
133+
val_task_data_preprocessors[task_name] = task_data_preprocessors[task_name]
127134
if has_envs:
128135
val_task_to_env[task_name] = task_to_env[task_name]
129136

@@ -144,6 +151,8 @@ def setup_response_data(
144151
val_data.task_spec,
145152
val_data.processor,
146153
)
154+
if hasattr(val_data, "preprocessor") and val_data.preprocessor is not None:
155+
val_task_data_preprocessors[task_name] = val_data.preprocessor
147156
if has_envs:
148157
val_task_to_env[task_name] = envs[cfg["env_name"]]
149158

@@ -155,6 +164,7 @@ def setup_response_data(
155164
tokenizer,
156165
None,
157166
val_task_data_processors,
167+
task_data_preprocessors=val_task_data_preprocessors,
158168
max_seq_length=data_config["max_input_seq_length"],
159169
)
160170
print(f" ✓ Validation dataset loaded with {len(val_dataset)} samples.")

0 commit comments

Comments (0)