-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_model.py
More file actions
63 lines (54 loc) · 2.06 KB
/
test_model.py
File metadata and controls
63 lines (54 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# Base model on the Hugging Face Hub and the local directory holding the
# fine-tuned LoRA adapter weights (presumably produced by a companion
# training script — confirm the path before running).
MODEL_NAME = "Qwen/Qwen3-4B"
ADAPTER_PATH = "./taikai-support-model"
# Load base model + LoRA adapter
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# float16 halves the memory footprint of the 4B-parameter weights.
base_model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float16,
)
# Attach the LoRA adapter on top of the frozen base weights.
model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
model.eval()  # inference mode: disables dropout / training-only behavior
# Move to MPS if available
# Prefer Apple-silicon GPU (Metal) when present; otherwise fall back to CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"
model = model.to(device)
def ask(question):
    """Return the model's answer to a single support *question*.

    Builds a two-turn chat (system + user), renders it through the model's
    chat template with thinking disabled, samples a completion, and decodes
    only the newly generated tokens.
    """
    chat = [
        {"role": "system", "content": "You are a helpful customer support assistant for TAIKAI, a hackathon and open innovation platform. Answer questions accurately and concisely."},
        {"role": "user", "content": question},
    ]
    # Disable thinking mode for direct answers (no <think> blocks)
    prompt = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = encoded["input_ids"].shape[1]
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )
    # Slice off the prompt tokens so only the model's continuation is decoded.
    return tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
# Test with various questions — mixing exact FAQ phrasing with natural user language
test_questions = [
    "How do I create a TAIKAI account?",
    "yo how do i get into a hackathon",
    "my project wont publish what do i do",
    "how does the voting system work for judges?",
    "Can I withdraw my LX tokens?",
    "i signed up with google but now i cant find my account",
    "What is a POP and how do I mint one?",
    "how do i find teammates for a hackathon",
    "Is there a way to reset my 2FA?",
    "what's the difference between a challenge and a hackathon on taikai",
]

# Print each Q/A pair followed by a visual divider.
divider = "-" * 60
for question in test_questions:
    print(f"\nQ: {question}")
    print(f"A: {ask(question)}")
    print(divider)