|
45 | 45 | "import os\n", |
46 | 46 | "\n", |
47 | 47 | "from transformers import AutoTokenizer\n", |
| 48 | + "\n", |
48 | 49 | "from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig" |
49 | 50 | ] |
50 | 51 | }, |
|
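The hunk above just adds a blank line between import groups (standard PEP 8 grouping). For readers following along, this is roughly how those imports come together later in the notebook; the checkpoint id and 4-bit setting below are illustrative assumptions, not values taken from this diff.

```python
# Sketch only: the model id and quantization settings are assumed
# for illustration, not read from the notebook.
from transformers import AutoTokenizer

from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # hypothetical checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Export to OpenVINO IR on the fly and compress weights to 4 bit.
model = OVModelForCausalLM.from_pretrained(
    model_id,
    export=True,
    quantization_config=OVWeightQuantizationConfig(bits=4),
)
```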
211 | 212 | "source": [ |
212 | 213 | "from transformers import TextStreamer\n", |
213 | 214 | "\n", |
| 215 | + "\n", |
214 | 216 | "# Tokenize the sample\n", |
215 | 217 | "inputs = tokenizer([sample], return_tensors='pt')\n", |
216 | 218 | "\n", |
|
294 | 296 | "\n", |
295 | 297 | "\n", |
296 | 298 | "# Tokenize the sample\n", |
297 | | - "inputs = tokenizer([sample], return_tensors='pt') \n", |
| 299 | + "inputs = tokenizer([sample], return_tensors='pt')\n", |
298 | 300 | "\n", |
299 | 301 | "out = stateless_model.generate(\n", |
300 | 302 | " **inputs,\n", |
301 | 303 | " max_new_tokens=128,\n", |
302 | 304 | " streamer=TextStreamer(tokenizer=tokenizer, skip_special_tokens=True),\n", |
303 | 305 | " pad_token_id=tokenizer.eos_token_id,\n", |
304 | 306 | " prompt_lookup_num_tokens=3,\n", |
305 | | - ") " |
| 307 | + ")" |
306 | 308 | ] |
307 | 309 | }, |
308 | 310 | { |
|
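The hunk above trims trailing whitespace around a `generate()` call that passes `prompt_lookup_num_tokens=3`, i.e. prompt lookup decoding: draft tokens come from matching the tail of the sequence against earlier text instead of from a separate draft model. A simplified sketch of that idea (not transformers' internal implementation):

```python
def propose_candidates(input_ids, ngram_size=2, num_pred_tokens=3):
    """Draft tokens by n-gram lookup over the existing sequence (sketch)."""
    tail = input_ids[-ngram_size:]
    # Walk backwards looking for an earlier occurrence of the trailing n-gram.
    for start in range(len(input_ids) - ngram_size - 1, -1, -1):
        if input_ids[start : start + ngram_size] == tail:
            follow = input_ids[start + ngram_size : start + ngram_size + num_pred_tokens]
            if follow:
                return follow  # verified against the target model in one forward pass
    return []  # no match: fall back to ordinary decoding


print(propose_candidates([5, 8, 9, 7, 2, 5, 8, 1, 4, 5, 8]))
# -> [1, 4, 5]: the tokens that followed the earlier occurrence of "5, 8"
```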
442 | 444 | "outputs": [], |
443 | 445 | "source": [ |
444 | 446 | "from functools import wraps\n", |
| 447 | + "\n", |
445 | 448 | "import numpy as np\n", |
446 | 449 | "\n", |
447 | 450 | "\n", |
|
458 | 461 | " if len(self.seq_lens) > 0 or len(self.win_sizes) > 0:\n", |
459 | 462 | " raise RuntimeError(\"Always use a new instance, don't reuse!\")\n", |
460 | 463 | " self.model_forward = self.model.forward\n", |
461 | | - " \n", |
| 464 | + "\n", |
462 | 465 | " @wraps(self.model_forward)\n", |
463 | 466 | " def forward_wrapper(**kwargs):\n", |
464 | 467 | " self.seq_lens[-1].append(kwargs.get(\"attention_mask\").shape[-1])\n", |
465 | 468 | " self.win_sizes[-1].append(kwargs.get(\"input_ids\").shape[-1] - 1)\n", |
466 | 469 | " return self.model_forward(**kwargs)\n", |
467 | | - " \n", |
| 470 | + "\n", |
468 | 471 | " self.model.forward = forward_wrapper\n", |
469 | | - " \n", |
| 472 | + "\n", |
470 | 473 | " # wrap generate method\n", |
471 | 474 | " self.model_generate = self.model.generate\n", |
472 | 475 | "\n", |
|
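These whitespace-only changes sit inside a recorder that monkey-patches `model.forward` to log the attention-mask length and draft-window size of every call. The same wrapping trick in isolation, with hypothetical names:

```python
from functools import wraps


def record_calls(obj, method_name, log):
    """Replace obj.method_name with a wrapper that logs each call (sketch)."""
    original = getattr(obj, method_name)

    @wraps(original)  # preserve the wrapped method's name/signature metadata
    def wrapper(*args, **kwargs):
        log.append((args, tuple(kwargs)))  # record what each call received
        return original(*args, **kwargs)

    setattr(obj, method_name, wrapper)
    return original  # keep a handle so the patch can be undone later
```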
494 | 497 | " self.seq_lens = [sl[1:] for sl in self.seq_lens]\n", |
495 | 498 | " # Add window size for output to ease calculation later\n", |
496 | 499 | " for ws, sl in zip(self.win_sizes, self.seq_lens):\n", |
497 | | - " ws.append(0) \n", |
| 500 | + " ws.append(0)\n", |
498 | 501 | "\n", |
499 | 502 | " def acceptance_rate(self, return_mean=True, normalize=False):\n", |
500 | 503 | " # ar_per_win = ((cur_seq_len - cur_win_size) - (prev_seq_len - prev_win_size) - 1) / prev_win_size\n", |
|
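To make the `ar_per_win` formula in the hunk above concrete, here is one worked instance with made-up numbers; the final `- 1` discounts the one token the target model contributes itself on every cycle, so only genuinely accepted draft tokens are counted.

```python
# Illustrative numbers only. Previous call: seq_len 10, window 3;
# current call: seq_len 14, window 3.
prev_seq_len, prev_win_size = 10, 3
cur_seq_len, cur_win_size = 14, 3

ar_per_win = ((cur_seq_len - cur_win_size) - (prev_seq_len - prev_win_size) - 1) / prev_win_size
print(ar_per_win)  # 1.0 -> all 3 drafted tokens were accepted
```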
533 | 536 | "metadata": {}, |
534 | 537 | "outputs": [], |
535 | 538 | "source": [ |
536 | | - "from tqdm import tqdm\n", |
537 | 539 | "from datasets import load_dataset\n", |
| 540 | + "from tqdm import tqdm\n", |
| 541 | + "\n", |
538 | 542 | "\n", |
539 | 543 | "dataset_name = \"openai_humaneval\"\n", |
540 | 544 | "dataset_subset_name = None\n", |
|
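Given `dataset_name` and `dataset_subset_name` above, the load step presumably looks like the snippet below; `openai_humaneval` ships a single `test` split whose `prompt` column holds the function signatures used as generation inputs (split and column names are standard for this dataset but assumed here, not shown in the diff).

```python
dataset = load_dataset(dataset_name, dataset_subset_name, split="test")
print(dataset[0]["prompt"][:80])  # peek at the first HumanEval prompt
```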
590 | 594 | "from threading import Thread\n", |
591 | 595 | "\n", |
592 | 596 | "from transformers import (\n", |
593 | | - " TextIteratorStreamer,\n", |
| 597 | + " GenerationConfig,\n", |
594 | 598 | " StoppingCriteria,\n", |
595 | 599 | " StoppingCriteriaList,\n", |
596 | | - " GenerationConfig,\n", |
| 600 | + " TextIteratorStreamer,\n", |
597 | 601 | ")\n", |
598 | 602 | "\n", |
599 | 603 | "\n", |
|
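The reordered imports above support the usual background-generation pattern: `generate()` runs in a worker thread while the main thread drains the `TextIteratorStreamer` as tokens arrive. A minimal sketch, assuming `model`, `tokenizer`, and `inputs` from earlier cells:

```python
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
thread = Thread(target=model.generate, kwargs=dict(**inputs, max_new_tokens=128, streamer=streamer))
thread.start()
for new_text in streamer:
    print(new_text, end="", flush=True)  # forward each chunk as it is produced
thread.join()
```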
690 | 694 | " prompt_char = \"▌\"\n", |
691 | 695 | " history[-1][1] = prompt_char\n", |
692 | 696 | " yield history, \"Status: Generating...\", *([gr.update(interactive=False)] * 4)\n", |
693 | | - " \n", |
| 697 | + "\n", |
694 | 698 | " streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n", |
695 | 699 | "\n", |
696 | 700 | "    # Create a stopping criterion to prevent the model from playing the role of the user as well.\n", |
|
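The comment in the hunk above refers to a custom stopping criterion whose definition is outside this diff. One common way to write such a criterion, sketched with hypothetical names:

```python
class SuffixCriteria(StoppingCriteria):
    """Stop once the sequence ends with the given token ids (sketch)."""

    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids, scores, **kwargs):
        tail = input_ids[0][-len(self.stop_ids):].tolist()
        return tail == self.stop_ids  # returning True halts generation


# Usage: generate(..., stopping_criteria=StoppingCriteriaList([SuffixCriteria(ids)]))
```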
770 | 774 | "source": [ |
771 | 775 | "import gradio as gr\n", |
772 | 776 | "\n", |
| 777 | + "\n", |
773 | 778 | "try:\n", |
774 | 779 | " demo.close()\n", |
775 | 780 | "except:\n", |
|
808 | 813 | " history: conversation history\n", |
809 | 814 | " Returns:\n", |
810 | 815 | " updated history\n", |
811 | | - " \"\"\" \n", |
| 816 | + " \"\"\"\n", |
812 | 817 | " history[-1][1] = None\n", |
813 | 818 | " return history\n", |
814 | 819 | "\n", |
|