Commit fc07060

Local inference (#27)
* First pass at load and run
* Return first decoded entry, load docstring
* Initial implementation for adapter config patcher
* Add adapter config patcher to util init
* First pass at script for loading and running a model with CLI
* Add base model override to cli for inference script
* Return immediately if no overrides are given
* Add adapter config overrides to inference script
* CLI support for processing one or more texts
* Docstring updates for load / run
* Refactor train into tuned model classmethod
* Move inference CLI to a separate script
* Infer device for inference
* adapter config docstrings
* Add inference instructions
* Add max new tokens as an arg to run inference
* Split inference and tuning back apart
* Consolidate inference cli and tuned model class
* Consolidate adapter config patcher into inference script
* Move inference script outside of tuning package
* Update readme inference instructions

Signed-off-by: Alex-Brooks <[email protected]>
1 parent 304b179 commit fc07060

3 files changed: 307 additions, 25 deletions

README.md

Lines changed: 50 additions & 0 deletions
@@ -113,6 +113,56 @@ tuning/sft_trainer.py \
 
 For `GPTBigCode` models, Hugging Face has enabled Flash v2 and one can simply replace the `'LlamaDecoderLayer'` with `'GPTBigCodeBlock'` in `tuning/config/fsdp_config.json` for proper sharding of the model.
 
+## Inference
+Currently, we do *not* offer inference support as part of the library, but we provide a standalone script for running inference on tuned models for testing purposes. For a full list of options, run `python scripts/run_inference.py --help`. Note that no data formatting / templating is applied at inference time.
+
+### Running a single example
+If you want to run a single example through a model, you can pass it with the `--text` flag.
+
+```bash
+python scripts/run_inference.py \
+--model my_checkpoint \
+--text "This is a text the model will run inference on" \
+--max_new_tokens 50 \
+--out_file result.json
+```
+
+### Running multiple examples
+To run multiple examples, pass a path to a file containing each source text on its own line. Example:
+
+Contents of `source_texts.txt`:
+```
+This is the first text to be processed.
+And this is the second text to be processed.
+```
+
+```bash
+python scripts/run_inference.py \
+--model my_checkpoint \
+--text_file source_texts.txt \
+--max_new_tokens 50 \
+--out_file result.json
+```
+
+### Inference Results Format
+After running the inference script, the specified `--out_file` will be a JSON file in which each entry holds the original input string and the predicted output string, as follows. Note that, due to the implementation of `.generate()` in Transformers, the input string will generally be contained in the output string as well.
+```
+[
+    {
+        "input": "{{Your input string goes here}}",
+        "output": "{{Generated result of processing your input string goes here}}"
+    },
+    ...
+]
+```
+
+### Changing the Base Model for Inference
+If you tuned a model using a *local* base model, then a machine-specific path will be saved into your checkpoint by Peft, specifically in the `adapter_config.json`. This can be problematic if you are running inference on a different machine than the one you used for tuning.
+
+As a workaround, the inference CLI provides a `--base_model_name_or_path` arg through which a new base model may be passed for running inference. This will patch the `base_model_name_or_path` in your checkpoint's `adapter_config.json` while loading the model, and restore it to its original value after completion. Alternatively, if you like, you can change the config's value yourself.
+
+NOTE: This can also be an issue for tokenizers (with the `tokenizer_name_or_path` config entry). We currently do not allow tokenizer patching, since the tokenizer can also be explicitly configured within the base model and checkpoint model, but we may choose to expose an override for `tokenizer_name_or_path` in the future.
+
 ## Validation
 
 We can use [`lm-evaluation-harness`](https://github.com/EleutherAI/lm-evaluation-harness) from EleutherAI for evaluating the generated model. For example, for the Llama-13B model, using the above command and the model at the end of epoch 5, we evaluated the MMLU score to be `53.9`, compared to `52.8` for the base model.
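
The `--base_model_name_or_path` override described in the README section above can be exercised as follows. This is a hedged sketch: `my_checkpoint` and `org/base-model` are placeholder names, not values taken from this change.

```bash
# Hypothetical invocation: patch base_model_name_or_path to a different model reference while
# loading, then restore the original value in adapter_config.json once loading completes.
python scripts/run_inference.py \
    --model my_checkpoint \
    --base_model_name_or_path org/base-model \
    --text "This is a text the model will run inference on" \
    --max_new_tokens 50 \
    --out_file result.json
```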

scripts/run_inference.py

Lines changed: 234 additions & 0 deletions
@@ -0,0 +1,234 @@
+"""CLI for loading a tuned model and running one or more inference calls on it.
+
+NOTE: For the moment, this script is intentionally written to contain all dependencies for two
+reasons:
+    - to keep it portable and not deal with managing multiple local packages.
+    - because we don't currently plan on supporting inference as a library; i.e., this is only for
+      testing.
+
+If these things change in the future, we should consider breaking it up.
+"""
+import argparse
+import json
+import os
+from peft import AutoPeftModelForCausalLM
+import torch
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+
+### Utilities
+class AdapterConfigPatcher:
+    """Adapter config patcher is a context manager for patching overrides into a config;
+    it will locate the adapter_config.json in a checkpoint directory and patch a dict of provided
+    overrides while inside the context block, and restore them when it leaves. This DOES actually
+    write to the file, so it's NOT safe to use in parallel inference runs.
+
+    Example:
+        overrides = {"base_model_name_or_path": "foo"}
+        with AdapterConfigPatcher(checkpoint_path, overrides):
+            # When loaded in this block, the config's base_model_name_or_path is "foo"
+            peft_model = AutoPeftModelForCausalLM.from_pretrained(checkpoint_path)
+    """
+    def __init__(self, checkpoint_path: str, overrides: dict):
+        self.checkpoint_path = checkpoint_path
+        self.overrides = overrides
+        self.config_path = AdapterConfigPatcher._locate_adapter_config(self.checkpoint_path)
+        # Values that we will patch later on
+        self.patched_values = {}
+
+    @staticmethod
+    def _locate_adapter_config(checkpoint_path: str) -> str:
+        """Given a path to a tuned checkpoint, tries to find the adapter_config
+        that will be loaded through the Peft auto model API.
+
+        Args:
+            checkpoint_path: str
+                Checkpoint model, which presumably holds an adapter config.
+
+        Returns:
+            str
+                Path to the located adapter_config file.
+        """
+        config_path = os.path.join(checkpoint_path, "adapter_config.json")
+        if not os.path.isfile(config_path):
+            raise FileNotFoundError(f"Could not locate adapter config: {config_path}")
+        return config_path
+
+    def _apply_config_changes(self, overrides: dict) -> dict:
+        """Applies a patch to a config with some override dict, returning the values
+        that we patched over so that they may be restored later.
+
+        Args:
+            overrides: dict
+                Overrides to write into the adapter_config.json. Currently, we
+                require all override keys to be defined in the config being patched.
+
+        Returns:
+            dict
+                Dict containing the values that we patched over.
+        """
+        # If we have no overrides, this context manager is a noop; no need to do anything
+        if not overrides:
+            return {}
+        with open(self.config_path, "r") as config_file:
+            adapter_config = json.load(config_file)
+        overridden_values = self._get_old_config_values(adapter_config, overrides)
+        adapter_config = {**adapter_config, **overrides}
+        with open(self.config_path, "w") as config_file:
+            json.dump(adapter_config, config_file, indent=4)
+        return overridden_values
+
+    @staticmethod
+    def _get_old_config_values(adapter_config: dict, overrides: dict) -> dict:
+        """Grabs the existing config subdict that we are going to clobber from the
+        loaded adapter_config.
+
+        Args:
+            adapter_config: dict
+                Adapter config whose values we are interested in patching.
+            overrides: dict
+                Dict of overrides, containing keys defined in the adapter_config with
+                new values.
+
+        Returns:
+            dict
+                The subdictionary of adapter_config, containing the keys in overrides,
+                with the values that we are going to replace.
+        """
+        # For now, we only expect to patch the base model; this may change in the future,
+        # but ensure that anything we are patching is defined in the original config
+        if not set(overrides.keys()).issubset(set(adapter_config.keys())):
+            raise KeyError("Adapter config overrides must be set in the config being patched")
+        return {key: adapter_config[key] for key in overrides}
+
+    def __enter__(self):
+        """Apply the config overrides and save the patched values."""
+        self.patched_values = self._apply_config_changes(self.overrides)
+
+    def __exit__(self, exc_type, exc_value, exc_tb):
+        """Restore the original values that our overrides patched over."""
+        self._apply_config_changes(self.patched_values)
+
+
+### Funcs for loading and running models
+class TunedCausalLM:
+    def __init__(self, model, tokenizer, device):
+        self.peft_model = model
+        self.tokenizer = tokenizer
+        self.device = device
+
+    @classmethod
+    def load(cls, checkpoint_path: str, base_model_name_or_path: str=None) -> "TunedCausalLM":
+        """Loads an instance of this model.
+
+        Args:
+            checkpoint_path: str
+                Checkpoint model to be loaded, which is a directory containing an
+                adapter_config.json.
+            base_model_name_or_path: str [Default: None]
+                Override for the base model to be used.
+
+        By default, the paths for the base model and tokenizer are contained within the adapter
+        config of the tuned model. Note that in this context, a path may refer to a model to be
+        downloaded from HF hub, or a local path on disk, the latter of which we must be careful
+        with when using a model that was written on a different device.
+
+        Returns:
+            TunedCausalLM
+                An instance of this class on which we can run inference.
+        """
+        overrides = {"base_model_name_or_path": base_model_name_or_path} if base_model_name_or_path is not None else {}
+        tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
+        # Apply the configs to the adapter config of this model; if no overrides
+        # are provided, then the context manager doesn't have any effect.
+        with AdapterConfigPatcher(checkpoint_path, overrides):
+            try:
+                peft_model = AutoPeftModelForCausalLM.from_pretrained(checkpoint_path)
+            except OSError as e:
+                print("Failed to initialize checkpoint model!")
+                raise e
+        device = "cuda" if torch.cuda.is_available() else None
+        print(f"Inferred device: {device}")
+        peft_model.to(device)
+        return cls(peft_model, tokenizer, device)
+
+
+    def run(self, text: str, *, max_new_tokens: int) -> str:
+        """Runs inference on an instance of this model.
+
+        Args:
+            text: str
+                Text on which we want to run inference.
+            max_new_tokens: int
+                Max new tokens to use for inference.
+
+        Returns:
+            str
+                Text generation result.
+        """
+        tok_res = self.tokenizer(text, return_tensors="pt")
+        input_ids = tok_res.input_ids.to(self.device)
+
+        peft_outputs = self.peft_model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens)
+        decoded_result = self.tokenizer.batch_decode(peft_outputs, skip_special_tokens=False)[0]
+        return decoded_result
+
+
+### Main & arg parsing
+def main():
+    parser = argparse.ArgumentParser(
+        description="Loads a tuned model and runs an inference call(s) through it"
+    )
+    parser.add_argument("--model", help="Path to tuned model to be loaded", required=True)
+    parser.add_argument(
+        "--out_file",
+        help="JSON file to write results to",
+        default="inference_result.json",
+    )
+    parser.add_argument(
+        "--base_model_name_or_path",
+        help="Override for base model to be used [default: value in model adapter_config.json]",
+        default=None
+    )
+    parser.add_argument(
+        "--max_new_tokens",
+        help="max new tokens to use for inference",
+        type=int,
+        default=20,
+    )
+    group = parser.add_mutually_exclusive_group(required=True)
+    group.add_argument("--text", help="Text to run inference on")
+    group.add_argument("--text_file", help="File to be processed where each line is a text to run inference on")
+    args = parser.parse_args()
+    # If we passed a file, check if it exists before doing anything else
+    if args.text_file and not os.path.isfile(args.text_file):
+        raise FileNotFoundError(f"Text file: {args.text_file} does not exist!")
+
+    # Load the model
+    loaded_model = TunedCausalLM.load(
+        checkpoint_path=args.model,
+        base_model_name_or_path=args.base_model_name_or_path,
+    )
+
+    # Run inference on the text; if multiple were provided, process them all
+    if args.text:
+        texts = [args.text]
+    else:
+        with open(args.text_file, "r") as text_file:
+            texts = [line.strip() for line in text_file.readlines()]
+
+    # TODO: we should add batch inference support
+    results = [
+        {"input": text, "output": loaded_model.run(text, max_new_tokens=args.max_new_tokens)}
+        for text in tqdm(texts)
+    ]
+
+    # Export the results to a file
+    with open(args.out_file, "w") as out_file:
+        json.dump(results, out_file, sort_keys=True, indent=4)
+
+    print(f"Exported results to: {args.out_file}")
+
+if __name__ == "__main__":
+    main()
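
The classes above can also be driven directly from Python instead of through the CLI. A minimal sketch, assuming it is executed from the `scripts/` directory and that `my_checkpoint/` is a placeholder checkpoint directory containing an `adapter_config.json` (neither is shipped with this change):

```python
# Hedged usage sketch for run_inference.py; "my_checkpoint" is a hypothetical path.
from run_inference import TunedCausalLM

model = TunedCausalLM.load(
    checkpoint_path="my_checkpoint",
    base_model_name_or_path=None,  # or a local path / HF hub ID to patch into adapter_config.json
)

prompt = "This is a text the model will run inference on"
full_text = model.run(prompt, max_new_tokens=50)

# .generate() echoes the prompt, so the decoded result generally begins with the input text;
# strip it off if only the continuation is wanted.
continuation = full_text.split(prompt, 1)[-1] if prompt in full_text else full_text
print(continuation)
```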

tuning/sft_trainer.py

Lines changed: 23 additions & 25 deletions
@@ -1,25 +1,21 @@
-from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaTokenizerFast, GPTNeoXTokenizerFast, GPT2Tokenizer
+import os
+from typing import Optional, Union
+
+import datasets
 import fire
-from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
-import transformers
+from peft.utils.other import fsdp_auto_wrap_policy
 import torch
-import datasets
-
-from tuning.data import tokenizer_data_utils
+import transformers
+from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaTokenizerFast, GPTNeoXTokenizerFast, GPT2Tokenizer
+from transformers.utils import logging
+from transformers import TrainerCallback
+from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
+from tuning.aim_loader import get_aimstack_callback
 from tuning.config import configs, peft_config
+from tuning.data import tokenizer_data_utils
 from tuning.utils.config_utils import get_hf_peft_config
 from tuning.utils.data_type_utils import get_torch_dtype
 
-from tuning.aim_loader import get_aimstack_callback
-from transformers.utils import logging
-from dataclasses import asdict
-from typing import Optional, Union
-
-from peft import LoraConfig
-import os
-from transformers import TrainerCallback
-from peft.utils.other import fsdp_auto_wrap_policy
-
 class PeftSavingCallback(TrainerCallback):
     def on_save(self, args, state, control, **kwargs):
         checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
@@ -29,21 +25,22 @@ def on_save(self, args, state, control, **kwargs):
             os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))
 
 
+
 def train(
-        model_args: configs.ModelArguments,
-        data_args: configs.DataArguments,
-        train_args: configs.TrainingArguments,
-        peft_config: Optional[Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]] = None,
-):
+    model_args: configs.ModelArguments,
+    data_args: configs.DataArguments,
+    train_args: configs.TrainingArguments,
+    peft_config: Optional[Union[peft_config.LoraConfig, peft_config.PromptTuningConfig]] = None,
+):
     """Call the SFTTrainer
 
     Args:
        model_args: tuning.config.configs.ModelArguments
        data_args: tuning.config.configs.DataArguments
        train_args: tuning.config.configs.TrainingArguments
        peft_config: peft_config.LoraConfig for Lora tuning | \
-            peft_config.PromptTuningConfig for prompt tuning | \
-            None for fine tuning
+        peft_config.PromptTuningConfig for prompt tuning | \
+        None for fine tuning
            The peft configuration to pass to trainer
     """
     run_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1
@@ -62,7 +59,7 @@ def train(
     train_args.fsdp_config = {'xla':False}
 
     task_type = "CAUSAL_LM"
-    model = transformers.AutoModelForCausalLM.from_pretrained(
+    model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=train_args.cache_dir,
         torch_dtype=get_torch_dtype(model_args.torch_dtype),
@@ -74,7 +71,7 @@ def train(
     model.gradient_checkpointing_enable()
 
     # TODO: Move these to a config as well
-    tokenizer = transformers.AutoTokenizer.from_pretrained(
+    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=train_args.cache_dir,
        use_fast = True
@@ -170,6 +167,7 @@ def train(
        trainer.accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)
     trainer.train()
 
+
 def main(**kwargs):
     parser = transformers.HfArgumentParser(dataclass_types=(configs.ModelArguments,
                                                             configs.DataArguments,
