8383
8484from flax import nnx
8585
86- from MaxText import globals
86+ from MaxText . globals import MAXTEXT_REPO_ROOT
8787from MaxText import max_logging
8888from MaxText import max_utils
8989from MaxText import pyconfig
122122)
123123# Regex to extract the final numerical answer
124124MATCH_ANSWER = re .compile (rf"{ ANSWER_START } .*?([\d\.\,\$]{{1,}})" , flags = re .MULTILINE | re .DOTALL )
125- CHAT_TEMPLATE_PATH = f" { globals . MAXTEXT_REPO_ROOT } / src/ MaxText/ examples/ chat_templates/ math_qa.json"
125+ CHAT_TEMPLATE_PATH = os . path . join ( MAXTEXT_REPO_ROOT , " src" , " MaxText" , " examples" , " chat_templates" , " math_qa.json")
126126
127127
128128def get_test_dataset (config , tokenizer ):
129+ """Loads and prepares the test dataset from Hugging Face.
130+
131+ Args:
132+ config: The pyconfig object containing run configurations, including
133+ `hf_access_token`.
134+ tokenizer: The tokenizer for processing the text data.
135+
136+ Returns:
137+ A grain.MapDataset instance for the test split, with prompts and target
138+ answers.
139+ """
129140 template_config = instruction_data_processing .load_template_from_file (CHAT_TEMPLATE_PATH )
130141 dataset = datasets .load_dataset (
131142 DATASET_NAME ,
@@ -159,7 +170,17 @@ def get_test_dataset(config, tokenizer):
159170
160171
161172def evaluate_model (dataset , vllm_rollout , debug = True ):
162- """Runs evaluation on the model using vLLM."""
173+ """Runs evaluation on the model using vLLM.
174+
175+ Args:
176+ dataset: The dataset to evaluate on.
177+ vllm_rollout: The vLLM rollout object for generating responses.
178+ debug: If True, prints debug information for each sample.
179+
180+ Returns:
181+ A dictionary containing evaluation scores: 'correct', 'partially_correct',
182+ and 'correct_format' percentages.
183+ """
163184 rollout_config = base_rollout .RolloutConfig (
164185 max_tokens_to_generate = MAX_TOKENS_TO_GENERATE ,
165186 max_prompt_length = MAX_PROMPT_LENGTH ,
@@ -201,12 +222,35 @@ def evaluate_model(dataset, vllm_rollout, debug=True):
201222
202223
def safe_string_to_float(text):
  """Strips formatting characters so the string can be parsed as a float.

  Drops commas, spaces, and dollar signs, e.g. "$2,125" -> "2125" and
  "$50" -> "50".

  Args:
    text: The input string.

  Returns:
    The cleaned string.
  """
  # Remove all three characters in a single C-level pass.
  return text.translate(str.maketrans("", "", ", $"))
207238
208239
209240def score_response (target , prediction , debug = True ):
241+ """Scores the model's prediction against the target answer.
242+
243+ It checks for exact correctness, partial correctness (within 10%), and
244+ whether the response follows the expected format.
245+
246+ Args:
247+ target: The ground truth answer string.
248+ prediction: The model's generated response string.
249+ debug: If True, prints exceptions during scoring.
250+
251+ Returns:
252+ A tuple of booleans: (is_correct, is_partially_correct, has_correct_format).
253+ """
210254 is_correct , is_partially_correct , has_correct_format = False , False , False
211255 extracted_response = guess .group (1 ) if (guess := MATCH_ANSWER .search (prediction )) is not None else ""
212256 extracted_response = safe_string_to_float (extracted_response )
@@ -231,6 +275,17 @@ def score_response(target, prediction, debug=True):
231275
232276
233277def create_vllm_rollout (config , model , mesh , tokenizer ):
278+ """Creates a vLLM rollout engine for text generation.
279+
280+ Args:
281+ config: The pyconfig object containing run configurations.
282+ model: The NNX model graph.
283+ mesh: The JAX device mesh.
284+ tokenizer: The tokenizer.
285+
286+ Returns:
287+ A VllmRollout instance configured for the model and hardware.
288+ """
234289 tunix_model = TunixMaxTextAdapter (model )
235290 return VllmRollout (
236291 model = tunix_model ,
@@ -245,6 +300,14 @@ def create_vllm_rollout(config, model, mesh, tokenizer):
245300
246301
247302def get_tokenizer (config ):
303+ """Initializes and returns the tokenizer.
304+
305+ Args:
306+ config: The pyconfig object with `tokenizer_path` and `hf_access_token`.
307+
308+ Returns:
309+ A Hugging Face tokenizer instance.
310+ """
248311 tokenizer = transformers .AutoTokenizer .from_pretrained (
249312 config .tokenizer_path ,
250313 token = config .hf_access_token ,
@@ -253,6 +316,11 @@ def get_tokenizer(config):
253316
254317
255318def train_and_evaluate (config ):
319+ """Orchestrates the pre-train evaluation, SFT, and post-train evaluation.
320+
321+ Args:
322+ config: The pyconfig object containing all run configurations.
323+ """
256324 tokenizer = get_tokenizer (config )
257325 test_dataset = get_test_dataset (config , tokenizer )
258326 test_dataset = test_dataset [:NUM_TEST_SAMPLES ]
@@ -261,16 +329,16 @@ def train_and_evaluate(config):
261329 vllm_rollout = create_vllm_rollout (config , trainer .model , mesh , tokenizer )
262330
263331 # 1. Pre-SFT Evaluation
264- max_logging .log (f "Running Pre-SFT evaluation..." )
332+ max_logging .log ("Running Pre-SFT evaluation..." )
265333 score = evaluate_model (test_dataset , vllm_rollout )
266334 print ("Score for PRE-SFT EVALUATION: " , score )
267335
268336 # 2. SFT Training
269- max_logging .log (f "Starting SFT training..." )
337+ max_logging .log ("Starting SFT training..." )
270338 trainer = sft_trainer .train_model (config , trainer , mesh )
271339
272340 # 3. Post-SFT Evaluation
273- max_logging .log (f "Running Post-SFT evaluation..." )
341+ max_logging .log ("Running Post-SFT evaluation..." )
274342 tunix_model = TunixMaxTextAdapter (trainer .model )
275343 state = nnx .state (tunix_model )
276344 vllm_rollout .update_params (state )
0 commit comments