This repository implements Beyond Prompt Engineering: A Reinforced Token-Level Input Refinement for Large Language Models.
Install the dependencies:
pip install -r requriment.txt
All datasets come from Hugging Face and can be loaded with the datasets package, for example:
import datasets
raw_datasets = datasets.load_dataset('hails/mmlu_no_train')
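To see what the model is actually scored on, each MMLU row can be rendered as a multiple-choice prompt. The sketch below assumes that rows carry the fields question, choices (four strings) and answer (an integer index from 0 to 3), as in the common MMLU schema; the helper names format_mmlu_prompt and gold_letter are hypothetical and not part of this repository. The layout only approximates the zero-shot format lm-eval uses for its mmlu task and is not a copy of the harness template.

```python
from typing import Sequence

LETTERS = "ABCD"


def format_mmlu_prompt(question: str, choices: Sequence[str]) -> str:
    """Render one multiple-choice record as a zero-shot prompt ending in 'Answer:'."""
    if len(choices) > len(LETTERS):
        raise ValueError("at most four choices are expected")
    lines = [question.strip()]
    # One line per option, labelled A, B, C, D in order.
    lines += [f"{letter}. {choice}" for letter, choice in zip(LETTERS, choices)]
    lines.append("Answer:")
    return "\n".join(lines)


def gold_letter(answer: int) -> str:
    """Convert the integer answer index (0 to 3) to its option letter."""
    return LETTERS[answer]
```

Assuming the loaded dataset exposes a test split, a prompt for the first row would be format_mmlu_prompt(row['question'], row['choices']) with row = raw_datasets['test'][0].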
The base models used in the experiments come from Hugging Face. Download each model to a local directory, then load it with the following code, setting model_path to that directory:
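from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = 'Llama-2-7b-hf'  # local directory containing the downloaded model
model = AutoModelForCausalLM.from_pretrained(model_path, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)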
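One way to download a model into your own path is huggingface_hub.snapshot_download with its local_dir argument. The sketch below is a minimal helper under that assumption; local_dir_for and download_model are hypothetical names, not part of this repository. Gated models such as meta-llama/Llama-2-7b-hf also require accepting the license on the Hub and logging in (for example with huggingface-cli login) before downloading.

```python
from pathlib import Path


def local_dir_for(repo_id: str, root: str = ".") -> str:
    """Map a Hub repo id such as 'meta-llama/Llama-2-7b-hf' to root/<model name>."""
    return str(Path(root) / repo_id.split("/")[-1])


def download_model(repo_id: str, root: str = ".") -> str:
    """Download a full model snapshot from the Hugging Face Hub and return its local path."""
    # Imported lazily so local_dir_for works even without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    local_dir = local_dir_for(repo_id, root)
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
    return local_dir
```

With the default root, download_model('meta-llama/Llama-2-7b-hf') places the files in Llama-2-7b-hf, which is the model_path used in the loading code above.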
We evaluate with the Language Model Evaluation Harness (https://github.com/EleutherAI/lm-evaluation-harness).
First, install lm-eval from source:
git clone https://github.com/EleutherAI/lm-evaluation-harness
cd lm-evaluation-harness
pip install -e .
Then run lm-eval to obtain the baseline results:
lm_eval --model hf --model_args pretrained=meta-llama/Llama-2-7b-hf --tasks mmlu --device cuda:0 --batch_size 32
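To collect baselines for several base models, the command above can be wrapped in a small script. The sketch below only rebuilds that exact lm_eval invocation and runs it with subprocess; baseline_command and run_baseline are hypothetical helpers, not part of this repository or of lm-eval, and no flags beyond those in the command above are used. The pretrained argument may be a Hub id or a local model directory.

```python
import shlex
import shutil
import subprocess
from typing import List


def baseline_command(pretrained: str, tasks: str = "mmlu", device: str = "cuda:0",
                     batch_size: int = 32) -> List[str]:
    """Build the lm_eval CLI call used for the baseline (same flags as the command above)."""
    return [
        "lm_eval",
        "--model", "hf",
        "--model_args", f"pretrained={pretrained}",
        "--tasks", tasks,
        "--device", device,
        "--batch_size", str(batch_size),
    ]


def run_baseline(pretrained: str, **kwargs) -> None:
    """Run one baseline evaluation; requires the lm_eval CLI to be installed."""
    if shutil.which("lm_eval") is None:
        raise RuntimeError("lm_eval not found on PATH; install lm-evaluation-harness first")
    cmd = baseline_command(pretrained, **kwargs)
    print("Running:", shlex.join(cmd))
    subprocess.run(cmd, check=True)  # raises CalledProcessError if the evaluation fails
```

For example, run_baseline('meta-llama/Llama-2-7b-hf') runs the same evaluation as the command above, and looping over a list of model paths produces one baseline per model.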
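To evaluate with refined inputs, we call the function change_emb (defined in test_func.py) inside _model_call in lm_eval/models/huggingface.py. The import in the patch expects test_func.py at lm_eval/models/LLM_filter/test_func.py. The modified method is: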
def _model_call(self, inps, attn_mask=None, labels=None):
    with torch.no_grad():
        if attn_mask is not None or labels is not None:
            assert attn_mask is not None and labels is not None
            assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
            return self.model(
                input_ids=inps, attention_mask=attn_mask, labels=labels
            ).logits
        else:
            assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
            # Add begin
            # Refine the input tokens with change_emb (from test_func.py) before the forward pass.
            from lm_eval.models.LLM_filter.test_func import change_emb
            B, L = inps.size()
            if L <= 1024:
                inps, new_att_mask = change_emb(inputs_ids=inps, tokenizer=self.tokenizer, device=self.device)
            else:
                # Longer inputs: refine only the first 1024 tokens and keep the rest unchanged.
                inputs_id = inps[:, :1024]
                inputs_id, new_att_mask = change_emb(inputs_ids=inputs_id, tokenizer=self.tokenizer, device=self.device)
                inps[:, :1024] = inputs_id
            # Add end
            # Note: new_att_mask is not passed to the model here.
            logits = self.model(inps).logits
            return logits
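The patch refines only the first 1024 tokens of longer inputs and writes them back in place, which assumes change_emb returns token ids with the same shape as its input. The list-based sketch below mirrors that control flow without torch so the splice is easy to follow. mask_token is a hypothetical stand-in for change_emb used only for illustration; it is not the refinement method from this repository. The explicit length check is an addition that makes the shape assumption visible.

```python
from typing import Callable, List, Tuple

Batch = List[List[int]]
# A refinement function takes a batch of token ids and returns (new_ids, attention_mask),
# mirroring the two values change_emb returns in the patch.
RefineFn = Callable[[Batch], Tuple[Batch, Batch]]


def refine_prefix(inps: Batch, refine: RefineFn, max_len: int = 1024) -> Batch:
    """Refine at most the first max_len tokens of every row and keep the tail unchanged."""
    if not inps or len(inps[0]) <= max_len:
        refined, _mask = refine(inps)  # the mask is discarded, as in the patch
        return refined
    prefix = [row[:max_len] for row in inps]
    refined_prefix, _mask = refine(prefix)
    # Writing the prefix back (inps[:, :1024] = inputs_id) only works if its length is unchanged.
    for new, old in zip(refined_prefix, prefix):
        if len(new) != len(old):
            raise ValueError("refine must return exactly max_len tokens per row")
    return [new + row[max_len:] for new, row in zip(refined_prefix, inps)]


def mask_token(batch: Batch, target: int = 7, pad_id: int = 0) -> Tuple[Batch, Batch]:
    """Hypothetical stand-in for change_emb: replace `target` with `pad_id` and mask it out."""
    new_ids = [[pad_id if t == target else t for t in row] for row in batch]
    mask = [[0 if t == target else 1 for t in row] for row in batch]
    return new_ids, mask
```

With max_len=3, refine_prefix([[1, 7, 3, 7, 5]], mask_token, max_len=3) returns [[1, 0, 3, 7, 5]]: the 7 inside the prefix is replaced, while the 7 in the tail is left untouched.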
Use main.py to train, setting model_path to the model you want to use. Only the following models are supported: Qwen2, Gemma, Llama-2, Llama-3, and Vicuna.
python main.py
Training writes the resulting model to model.pth in /result.
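Because main.py supports only these model families, it can help to check a model directory before training. The sketch below reads the model_type field from the checkpoint's config.json, a standard Hugging Face file; is_supported and check_model_path are hypothetical helpers, not part of main.py. Mapping Llama-2, Llama-3 and Vicuna to model_type "llama" reflects their shared Llama architecture; Gemma 2 reports a different model_type ("gemma2") and is not accepted by this check.

```python
import json
from pathlib import Path
from typing import Any, Dict

# Hugging Face `model_type` values for the families main.py supports.
# Llama-2, Llama-3 and Vicuna all report the "llama" architecture.
SUPPORTED_MODEL_TYPES = {"qwen2", "gemma", "llama"}


def is_supported(config: Dict[str, Any]) -> bool:
    """Return True if a parsed config.json belongs to a supported model family."""
    return config.get("model_type") in SUPPORTED_MODEL_TYPES


def check_model_path(model_path: str) -> str:
    """Read model_path/config.json and return its model_type, or raise if unsupported."""
    config_file = Path(model_path) / "config.json"
    config = json.loads(config_file.read_text(encoding="utf-8"))
    if not is_supported(config):
        raise ValueError(
            f"model_type {config.get('model_type')!r} is not supported; "
            f"expected one of {sorted(SUPPORTED_MODEL_TYPES)}"
        )
    return config["model_type"]
```

For a local Llama-2 download, check_model_path('Llama-2-7b-hf') would return "llama"; any other architecture raises a ValueError before training starts.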