@@ -84,6 +84,7 @@ def setup_response_data(
8484 print ("\n ▶ Setting up data..." )
8585 # setup train dataset
8686 task_data_processors = {}
87+ task_data_preprocessors = {}
8788 task_to_env = {}
8889 data_list = []
8990
@@ -99,6 +100,8 @@ def setup_response_data(
99100 # bind task_name to task_data_processors and task_to_env
100101 task_name = data .task_name
101102 task_data_processors [task_name ] = (data .task_spec , data .processor )
103+ if hasattr (data , "preprocessor" ) and data .preprocessor is not None :
104+ task_data_preprocessors [task_name ] = data .preprocessor
102105 if has_envs :
103106 task_to_env [task_name ] = envs [cfg ["env_name" ]]
104107
@@ -108,12 +111,14 @@ def setup_response_data(
108111 tokenizer ,
109112 None ,
110113 task_data_processors ,
114+ task_data_preprocessors = task_data_preprocessors ,
111115 max_seq_length = data_config ["max_input_seq_length" ],
112116 )
113117 print (f" ✓ Training dataset loaded with { len (dataset )} samples." )
114118
115119 # setup validation dataset
116120 val_task_data_processors = {}
121+ val_task_data_preprocessors = {}
117122 val_task_to_env = {}
118123 val_data_list = []
119124
@@ -124,6 +129,8 @@ def setup_response_data(
124129 # bind task_name to task_data_processors and task_to_env
125130 task_name = data .task_name
126131 val_task_data_processors [task_name ] = task_data_processors [task_name ]
132+ if task_name in task_data_preprocessors :
133+ val_task_data_preprocessors [task_name ] = task_data_preprocessors [task_name ]
127134 if has_envs :
128135 val_task_to_env [task_name ] = task_to_env [task_name ]
129136
@@ -144,6 +151,8 @@ def setup_response_data(
144151 val_data .task_spec ,
145152 val_data .processor ,
146153 )
154+ if hasattr (val_data , "preprocessor" ) and val_data .preprocessor is not None :
155+ val_task_data_preprocessors [task_name ] = val_data .preprocessor
147156 if has_envs :
148157 val_task_to_env [task_name ] = envs [cfg ["env_name" ]]
149158
@@ -155,6 +164,7 @@ def setup_response_data(
155164 tokenizer ,
156165 None ,
157166 val_task_data_processors ,
167+ task_data_preprocessors = val_task_data_preprocessors ,
158168 max_seq_length = data_config ["max_input_seq_length" ],
159169 )
160170 print (f" ✓ Validation dataset loaded with { len (val_dataset )} samples." )
0 commit comments