@@ -181,7 +181,7 @@ def verify_with_input_ids(
181181 original_model : ModelWrapper ,
182182 reauthored_model : ReauthoredModelWrapper ,
183183 input_ids : List [int ],
184- kv_cache_max_len : int = 1024 ,
184+ kv_cache_max_len : int = 128 ,
185185 rtol : float = 1e-05 ,
186186 atol : float = 1e-05 ,
187187):
@@ -273,6 +273,8 @@ def verify_reauthored_model(
273273 rtol : float = 1e-05 ,
274274 atol : float = 1e-05 ,
275275 continue_on_failure : bool = False ,
276+ verify_inputs : bool = True ,
277+ verify_prompts : bool = True ,
276278) -> bool :
277279 """Verifies the reauthored model against the original model.
278280
@@ -301,33 +303,37 @@ def verify_reauthored_model(
301303 """
302304 failure_count = 0
303305
304- for input_ids in forward_input_ids :
305- logging .info ("Verifying the reauthored model with input IDs: %s" , input_ids )
306- try :
307- verify_with_input_ids (
308- original_model , reauthored_model , input_ids , rtol = rtol , atol = atol
306+ if verify_inputs :
307+ for input_ids in forward_input_ids :
308+ logging .info (
309+ "Verifying the reauthored model with input IDs: %s" , input_ids
309310 )
310- except AssertionError as e :
311- logging .error ("*** FAILED *** verify with input IDs: %s" , input_ids )
312- failure_count += 1
313- if not continue_on_failure :
314- return False
315- else :
316- logging .info ("*** PASSED *** verify with input IDs: %s" , input_ids )
317-
318- for prompts in generate_prompts :
319- logging .info ("Verifying the reauthored model with prompts: %s" , prompts )
320- try :
321- verify_model_with_prompts (
322- original_model , reauthored_model , tokenizer , prompts , max_new_tokens
323- )
324- except AssertionError as e :
325- logging .error ("*** FAILED *** verify with prompts: %s" , prompts )
326- failure_count += 1
327- if not continue_on_failure :
328- return False
329- else :
330- logging .info ("*** PASSED *** verify with prompts: %s" , prompts )
311+ try :
312+ verify_with_input_ids (
313+ original_model , reauthored_model , input_ids , rtol = rtol , atol = atol
314+ )
315+ except AssertionError as e :
316+ logging .error ("*** FAILED *** verify with input IDs: %s" , input_ids )
317+ failure_count += 1
318+ if not continue_on_failure :
319+ return False
320+ else :
321+ logging .info ("*** PASSED *** verify with input IDs: %s" , input_ids )
322+
323+ if verify_prompts :
324+ for prompts in generate_prompts :
325+ logging .info ("Verifying the reauthored model with prompts: %s" , prompts )
326+ try :
327+ verify_model_with_prompts (
328+ original_model , reauthored_model , tokenizer , prompts , max_new_tokens
329+ )
330+ except AssertionError as e :
331+ logging .error ("*** FAILED *** verify with prompts: %s" , prompts )
332+ failure_count += 1
333+ if not continue_on_failure :
334+ return False
335+ else :
336+ logging .info ("*** PASSED *** verify with prompts: %s" , prompts )
331337
332338 if failure_count == 0 :
333339 logging .info ("*** PASSED *** verify_reauthored_model" )
0 commit comments