import json

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

st.set_page_config(page_title="MiniMind-V1 Demo (no chat history)")
st.title("MiniMind-V1 Demo (no chat history)")

model_id = "minimind-v1"

# -----------------------------------------------------------------------------
# Sampling settings: temperature and top_k control randomness; max_seq_len caps
# both the prompt length and the number of generated tokens.
temperature = 0.7
top_k = 8
max_seq_len = 1 * 1024
# -----------------------------------------------------------------------------


@st.cache_resource
def load_model_tokenizer():
    # Cached across Streamlit reruns so the weights are loaded only once.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        use_fast=False,
        trust_remote_code=True
    )
    model = model.eval()
    # Returned for completeness; the streaming generate() call below passes
    # its sampling arguments explicitly.
    generation_config = GenerationConfig.from_pretrained(model_id)
    return model, tokenizer, generation_config


def clear_chat_messages():
    # pop() rather than del, so clearing twice does not raise a KeyError.
    st.session_state.pop("messages", None)


def init_chat_messages():
    with st.chat_message("assistant", avatar='🤖'):
        st.markdown("Hello, I am MiniMind, developed by Joya. Happy to help! 😄")

    if "messages" in st.session_state:
        # Replay the stored conversation after a rerun.
        for message in st.session_state.messages:
            avatar = "🧑‍💻" if message["role"] == "user" else "🤖"
            with st.chat_message(message["role"], avatar=avatar):
                st.markdown(message["content"])
    else:
        st.session_state.messages = []

    return st.session_state.messages


# Optional sidebar controls for the sampling parameters:
# max_new_tokens = st.sidebar.slider("max_new_tokens", 0, 1024, 512, step=1)
# top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
# top_k = st.sidebar.slider("top_k", 0, 100, 0, step=1)
# temperature = st.sidebar.slider("temperature", 0.0, 2.0, 1.0, step=0.01)
# do_sample = st.sidebar.checkbox("do_sample", value=False)


def main():
    model, tokenizer, generation_config = load_model_tokenizer()
    messages = init_chat_messages()

    if prompt := st.chat_input("Shift + Enter for a new line, Enter to send"):
        with st.chat_message("user", avatar='🧑‍💻'):
            st.markdown(prompt)
        messages.append({"role": "user", "content": prompt})
        with st.chat_message("assistant", avatar='🤖'):
            placeholder = st.empty()

            # "No chat history": only the latest user turn is sent to the model.
            chat_messages = [{"role": "user", "content": prompt}]
            new_prompt = tokenizer.apply_chat_template(
                chat_messages,
                tokenize=False,
                add_generation_prompt=True
            )[-(max_seq_len - 1):]  # truncate by characters to bound the prompt

            x = tokenizer(new_prompt)['input_ids']
            x = torch.tensor(x, dtype=torch.long)[None, ...]  # add a batch dimension

            response = ''

            with torch.no_grad():
                # MiniMind's custom generate() (loaded via trust_remote_code)
                # yields partial output sequences when stream=True.
                res_y = model.generate(x, tokenizer.eos_token_id, max_new_tokens=max_seq_len,
                                       temperature=temperature, top_k=top_k, stream=True)
                try:
                    y = next(res_y)
                except StopIteration:
                    return

                while y is not None:
                    answer = tokenizer.decode(y[0].tolist())
                    if answer and answer[-1] == '�':
                        # A multi-byte character is still incomplete; wait for more tokens.
                        try:
                            y = next(res_y)
                        except StopIteration:
                            break
                        continue
                    if not answer:
                        try:
                            y = next(res_y)
                        except StopIteration:
                            break
                        continue
                    placeholder.markdown(answer)
                    response = answer
                    try:
                        y = next(res_y)
                    except StopIteration:
                        break

            # if contain_history_chat:
            #     assistant_answer = answer.replace(new_prompt, "")
            #     messages.append({"role": "assistant", "content": assistant_answer})

            messages.append({"role": "assistant", "content": response})
            # print("messages: ", json.dumps(response, ensure_ascii=False), flush=True)

    st.button("Clear conversation", on_click=clear_chat_messages)


if __name__ == "__main__":
    main()
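
# To launch the demo (assuming this file is saved as web_demo.py and the
# "minimind-v1" weights are available in that local directory):
#   streamlit run web_demo.py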