Skip to content

Commit 2d878d1

Browse files
authored
fix: Enforce tokenizer from approved LoRA base models (#81)
1 parent fef90b5 commit 2d878d1

File tree

1 file changed

+86
-4
lines changed

1 file changed

+86
-4
lines changed

src/validate.py

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,11 @@ def load_tokenizer(model_name_or_path: str) -> AutoTokenizer:
114114

115115

116116
def load_model(
117-
model_name_or_path: str, lora_only: bool, revision: str, val_args: TrainingArguments
117+
model_name_or_path: str,
118+
lora_only: bool,
119+
revision: str,
120+
val_args: TrainingArguments,
121+
cached_lora: bool,
118122
) -> Trainer:
119123
logger.info(f"Loading model from base model: {model_name_or_path}")
120124

@@ -129,7 +133,7 @@ def load_model(
129133
device_map=None,
130134
)
131135
# check whether it is a lora weight
132-
if download_lora_config(model_name_or_path, revision):
136+
if cached_lora:
133137
logger.info("Repo is a lora weight, loading model with adapter weights")
134138
with open("lora/adapter_config.json", "r") as f:
135139
adapter_config = json.load(f)
@@ -295,11 +299,89 @@ def validate(
295299
val_args = parser.parse_json_file(json_file=validation_args_file)[0]
296300
gpu_type = get_gpu_type()
297301

298-
tokenizer = load_tokenizer(model_name_or_path)
302+
tokenizer_model_path = model_name_or_path
303+
304+
# Determine the correct tokenizer path, especially for LoRA models
305+
is_lora = download_lora_config(model_name_or_path, revision)
306+
cached_lora = is_lora
307+
adapter_config_path = Path("lora/adapter_config.json")
308+
309+
if is_lora:
310+
if adapter_config_path.exists():
311+
logger.info(
312+
f"Model {model_name_or_path} is a LoRA model. Validating its base model for tokenizer."
313+
)
314+
try:
315+
with open(adapter_config_path, "r") as f:
316+
adapter_config = json.load(f)
317+
318+
lora_base_model_path = adapter_config.get("base_model_name_or_path")
319+
320+
if (
321+
not lora_base_model_path
322+
): # Check if base_model_name_or_path is missing
323+
logger.error(
324+
f"LoRA model {model_name_or_path} does not specify 'base_model_name_or_path' "
325+
f"in its adapter_config.json. Marking assignment {assignment_id} as failed."
326+
)
327+
if not local_test:
328+
fed_ledger.mark_assignment_as_failed(assignment_id)
329+
return # Exit validate function
330+
331+
# Check if the extracted base model path is in SUPPORTED_BASE_MODELS
332+
if lora_base_model_path in SUPPORTED_BASE_MODELS:
333+
logger.info(
334+
f"LoRA's base model '{lora_base_model_path}' is in SUPPORTED_BASE_MODELS. "
335+
f"Using it for tokenizer."
336+
)
337+
tokenizer_model_path = lora_base_model_path
338+
else:
339+
logger.error(
340+
f"LoRA's base model '{lora_base_model_path}' is not in SUPPORTED_BASE_MODELS. "
341+
f"Marking assignment {assignment_id} as failed."
342+
)
343+
if not local_test:
344+
fed_ledger.mark_assignment_as_failed(assignment_id)
345+
return
346+
347+
except json.JSONDecodeError:
348+
logger.error(
349+
f"Failed to decode adapter_config.json for {model_name_or_path}. "
350+
f"Marking assignment {assignment_id} as failed."
351+
)
352+
if not local_test:
353+
fed_ledger.mark_assignment_as_failed(assignment_id)
354+
return
355+
except Exception as e: # Catch any other generic exception during adapter_config processing
356+
logger.error(
357+
f"Error processing adapter_config.json for {model_name_or_path}: {e}. "
358+
f"Marking assignment {assignment_id} as failed."
359+
)
360+
if not local_test:
361+
fed_ledger.mark_assignment_as_failed(assignment_id)
362+
return
363+
else: # is_lora is True, but adapter_config.json does not exist
364+
logger.error(
365+
f"Model {model_name_or_path} is identified as LoRA, but its adapter_config.json was not downloaded or found at {adapter_config_path}. "
366+
f"This could be due to an issue with 'download_lora_config' or the repository structure for the LoRA model. "
367+
f"Marking assignment {assignment_id} as failed."
368+
)
369+
if not local_test:
370+
fed_ledger.mark_assignment_as_failed(assignment_id)
371+
return
372+
else: # Not a LoRA model
373+
logger.info(
374+
f"Model {model_name_or_path} is not identified as a LoRA model. "
375+
f"Using its own path for tokenizer: {model_name_or_path}."
376+
)
377+
378+
tokenizer = load_tokenizer(tokenizer_model_path)
299379
eval_dataset = load_sft_dataset(
300380
eval_file, context_length, template_name=base_model, tokenizer=tokenizer
301381
)
302-
model = load_model(model_name_or_path, lora_only, revision, val_args)
382+
model = load_model(
383+
model_name_or_path, lora_only, revision, val_args, cached_lora
384+
)
303385
# if model is not loaded, mark the assignment as failed and return
304386
if model is None:
305387
fed_ledger.mark_assignment_as_failed(assignment_id)

0 commit comments

Comments
 (0)