
Commit c6aa165
modify the initial run file
1 parent 03db245

File tree: 1 file changed, +28 −10 lines

examples/offline_inference_rerope.py

Lines changed: 28 additions & 10 deletions
@@ -23,14 +23,17 @@
 def setup_environment_variables():
     os.environ["VLLM_USE_V1"] = "1"
     os.environ["PYTHONHASHSEED"] = "123456"
-    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,6,7"
+
     os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"
     os.environ["REROPE_WINDOW"] = "32768"
     os.environ["TRAINING_LENGTH"] = "32768"
 
-    global data_dir
-    data_dir = os.getenv("DATA_DIR", "/home/externals/wangwenxin21/wx_data")
 
+    global data_dir
+    data_dir = os.getenv("DATA_DIR", "/home/data/kv_cache")
+    data_dir = input(
+        "Enter the directory for UCMStore to save kv cache, e.g. /home/data/kv_cache: "
+    )
     if not os.path.isdir(data_dir):
         create = input(f"Directory {data_dir} does not exist. Create it? (Y/n): ")
         if create.lower() == "y":
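The new data_dir block introduces the pattern this commit applies throughout: read a path from an environment variable, fall back to an interactive prompt, and bail out if the result still does not exist (here the prompt is unconditional; for the model and dataset below it only fires when the default path is missing). A minimal sketch of that pattern as a reusable helper, using only the standard library — resolve_path is a hypothetical name, not part of the commit:

import os
import sys

def resolve_path(env_var: str, default: str, prompt: str, expect_dir: bool = True) -> str:
    # Read the path from the environment, falling back to the given default.
    path = os.getenv(env_var, default)
    check = os.path.isdir if expect_dir else os.path.isfile
    if not check(path):
        # Ask the user once, then give up if the path is still invalid.
        path = input(prompt)
        if not check(path):
            print(f"Exiting. Invalid path for {env_var}: {path}")
            sys.exit(1)
    return path

# e.g. data_dir = resolve_path("DATA_DIR", "/home/data/kv_cache",
#          "Enter the directory for UCMStore to save kv cache: ")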
@@ -63,13 +66,13 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
         model=model,
         kv_transfer_config=ktc,
         hf_overrides={
-            "max_position_embeddings": 430080,
+            "max_position_embeddings": 327680,
         },
-        gpu_memory_utilization=0.8,
+        gpu_memory_utilization=0.9,
         max_num_batched_tokens=8192,
         block_size=16,
         enforce_eager=True,
-        tensor_parallel_size=4,
+        tensor_parallel_size=2,
     )
 
     llm = LLM(**asdict(llm_args))
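Two of these changes appear linked: tensor_parallel_size drops from 4 to 2 alongside the removal of the four-GPU CUDA_VISIBLE_DEVICES pin above, and the max_position_embeddings override moves from 430080 to 327680, which is exactly ten times the 32768 REROPE_WINDOW / TRAINING_LENGTH set in setup_environment_variables(), whereas the old value was not an integer multiple. A small check of that arithmetic, assuming (the commit does not say so explicitly) that the multiple-of-window relationship is intentional:

# Assumption: the context-length override is meant to be a whole number
# of ReRoPE windows; the commit itself does not state this.
REROPE_WINDOW = 32768
NEW_MAX_POS = 327680   # new override: 10 * 32768
OLD_MAX_POS = 430080   # old override: 13.125 * 32768

assert NEW_MAX_POS % REROPE_WINDOW == 0
assert OLD_MAX_POS % REROPE_WINDOW != 0
print(NEW_MAX_POS // REROPE_WINDOW)  # -> 10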
@@ -98,24 +101,39 @@ def print_output(
 def main():
     module_path = "ucm.integration.vllm.ucm_connector"
     name = "UCMConnector"
-    model = os.getenv("MODEL_PATH", "/home/wx/models/Qwen2.5-14B-Instruct")
+    model = os.getenv("MODEL_PATH", "/home/models/Qwen2.5-14B-Instruct")
+    if not os.path.isdir(model):
+        model = input("Enter the path to the model, e.g. /home/models/Qwen2.5-14B-Instruct: ")
+        if not os.path.isdir(model):
+            print("Exiting. Invalid model path")
+            sys.exit(1)
 
     tokenizer = AutoTokenizer.from_pretrained(model, use_chat_template=True)
     setup_environment_variables()
 
     with build_llm_with_uc(module_path, name, model) as llm:
 
         data_all = []
+        path_to_dataset = os.getenv(
+            "DATASET_PATH", "/home/data/Longbench/data/multifieldqa_zh.jsonl"
+        )
+        if not os.path.isfile(path_to_dataset):
+            path_to_dataset = input(
+                "Enter the path to one of the LongBench datasets, e.g. /home/data/Longbench/data/multifieldqa_zh.jsonl: "
+            )
+            if not os.path.isfile(path_to_dataset):
+                print("Exiting. Invalid dataset path")
+                sys.exit(1)
         with open(
-            "/home/wx/va_clean/data/multifieldqa_zh.jsonl", "r", encoding="utf-8"
+            path_to_dataset, "r", encoding="utf-8"
         ) as f:
             for line in f:
                 data_all.append(json.loads(line))
 
         materials = []
         questions = []
         references = []
-        batch_size = 75
+        batch_size = 30
         num_batch = 2
         for idx in range(num_batch):
             data = data_all[idx * batch_size : (idx + 1) * batch_size]
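With batch_size lowered from 75 to 30 and num_batch unchanged at 2, the loop now consumes only the first 60 records of the JSONL file. A self-contained sketch of the slicing, with a plain list standing in for the parsed LongBench records:

# Stand-in for data_all after json.loads() over each JSONL line.
data_all = [{"idx": i} for i in range(100)]

batch_size = 30
num_batch = 2
for idx in range(num_batch):
    data = data_all[idx * batch_size : (idx + 1) * batch_size]
    print(idx, len(data))  # -> "0 30" then "1 30"; records 60+ are unused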
@@ -151,7 +169,7 @@ def main():
             "【文本内容开始】\n"
             f"{material}\n"
             "【文本内容结束】\n\n"
-            "请回答以下问题\n"
+            "请直接回答以下问题\n"
             f"{question}"
         )
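This last hunk tightens the instruction from 请回答以下问题 ("please answer the following question") to 请直接回答以下问题 ("please answer the following question directly"). For reference, the full template after the change, with the Chinese literals glossed in comments (material and question are placeholders):

material = "..."  # one LongBench passage
question = "..."  # the question paired with that passage
prompt = (
    "【文本内容开始】\n"    # "[Begin text content]"
    f"{material}\n"
    "【文本内容结束】\n\n"  # "[End text content]"
    "请直接回答以下问题\n"  # "Please answer the following question directly"
    f"{question}"
)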
