@@ -114,7 +114,11 @@ def load_tokenizer(model_name_or_path: str) -> AutoTokenizer:
114114
115115
116116def load_model (
117- model_name_or_path : str , lora_only : bool , revision : str , val_args : TrainingArguments
117+ model_name_or_path : str ,
118+ lora_only : bool ,
119+ revision : str ,
120+ val_args : TrainingArguments ,
121+ cached_lora : bool ,
118122) -> Trainer :
119123 logger .info (f"Loading model from base model: { model_name_or_path } " )
120124
@@ -129,7 +133,7 @@ def load_model(
129133 device_map = None ,
130134 )
131135 # check whether it is a lora weight
132- if download_lora_config ( model_name_or_path , revision ) :
136+ if cached_lora :
133137 logger .info ("Repo is a lora weight, loading model with adapter weights" )
134138 with open ("lora/adapter_config.json" , "r" ) as f :
135139 adapter_config = json .load (f )
@@ -295,11 +299,89 @@ def validate(
295299 val_args = parser .parse_json_file (json_file = validation_args_file )[0 ]
296300 gpu_type = get_gpu_type ()
297301
298- tokenizer = load_tokenizer (model_name_or_path )
302+ tokenizer_model_path = model_name_or_path
303+
304+ # Determine the correct tokenizer path, especially for LoRA models
305+ is_lora = download_lora_config (model_name_or_path , revision )
306+ cached_lora = is_lora
307+ adapter_config_path = Path ("lora/adapter_config.json" )
308+
309+ if is_lora :
310+ if adapter_config_path .exists ():
311+ logger .info (
312+ f"Model { model_name_or_path } is a LoRA model. Validating its base model for tokenizer."
313+ )
314+ try :
315+ with open (adapter_config_path , "r" ) as f :
316+ adapter_config = json .load (f )
317+
318+ lora_base_model_path = adapter_config .get ("base_model_name_or_path" )
319+
320+ if (
321+ not lora_base_model_path
322+ ): # Check if base_model_name_or_path is missing
323+ logger .error (
324+ f"LoRA model { model_name_or_path } does not specify 'base_model_name_or_path' "
325+ f"in its adapter_config.json. Marking assignment { assignment_id } as failed."
326+ )
327+ if not local_test :
328+ fed_ledger .mark_assignment_as_failed (assignment_id )
329+ return # Exit validate function
330+
331+ # Check if the extracted base model path is in SUPPORTED_BASE_MODELS
332+ if lora_base_model_path in SUPPORTED_BASE_MODELS :
333+ logger .info (
334+ f"LoRA's base model '{ lora_base_model_path } ' is in SUPPORTED_BASE_MODELS. "
335+ f"Using it for tokenizer."
336+ )
337+ tokenizer_model_path = lora_base_model_path
338+ else :
339+ logger .error (
340+ f"LoRA's base model '{ lora_base_model_path } ' is not in SUPPORTED_BASE_MODELS. "
341+ f"Marking assignment { assignment_id } as failed."
342+ )
343+ if not local_test :
344+ fed_ledger .mark_assignment_as_failed (assignment_id )
345+ return
346+
347+ except json .JSONDecodeError :
348+ logger .error (
349+ f"Failed to decode adapter_config.json for { model_name_or_path } . "
350+ f"Marking assignment { assignment_id } as failed."
351+ )
352+ if not local_test :
353+ fed_ledger .mark_assignment_as_failed (assignment_id )
354+ return
355+ except Exception as e : # Catch any other generic exception during adapter_config processing
356+ logger .error (
357+ f"Error processing adapter_config.json for { model_name_or_path } : { e } . "
358+ f"Marking assignment { assignment_id } as failed."
359+ )
360+ if not local_test :
361+ fed_ledger .mark_assignment_as_failed (assignment_id )
362+ return
363+ else : # is_lora is True, but adapter_config.json does not exist
364+ logger .error (
365+ f"Model { model_name_or_path } is identified as LoRA, but its adapter_config.json was not downloaded or found at { adapter_config_path } . "
366+ f"This could be due to an issue with 'download_lora_config' or the repository structure for the LoRA model. "
367+ f"Marking assignment { assignment_id } as failed."
368+ )
369+ if not local_test :
370+ fed_ledger .mark_assignment_as_failed (assignment_id )
371+ return
372+ else : # Not a LoRA model
373+ logger .info (
374+ f"Model { model_name_or_path } is not identified as a LoRA model. "
375+ f"Using its own path for tokenizer: { model_name_or_path } ."
376+ )
377+
378+ tokenizer = load_tokenizer (tokenizer_model_path )
299379 eval_dataset = load_sft_dataset (
300380 eval_file , context_length , template_name = base_model , tokenizer = tokenizer
301381 )
302- model = load_model (model_name_or_path , lora_only , revision , val_args )
382+ model = load_model (
383+ model_name_or_path , lora_only , revision , val_args , cached_lora
384+ )
303385 # if model is not loaded, mark the assignment as failed and return
304386 if model is None :
305387 fed_ledger .mark_assignment_as_failed (assignment_id )
0 commit comments