import sys
from pathlib import Path
-from typing import List, Literal, TypedDict
+from typing import List, TypedDict
from unittest.mock import patch

import pytest
import torch
from llama_recipes.inference.chat_utils import read_dialogs_from_file

ROOT_DIR = Path(__file__).parents[2]
-CHAT_COMPLETION_DIR = ROOT_DIR / "recipes/inference/local_inference/chat_completion/"
+CHAT_COMPLETION_DIR = ROOT_DIR / "recipes/quickstart/inference/local_inference/chat_completion/"

sys.path = [CHAT_COMPLETION_DIR.as_posix()] + sys.path

-Role = Literal["user", "assistant"]
-
-
-class Message(TypedDict):
-    role: Role
-    content: str
-
-
-Dialog = List[Message]
-
-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
-
+default_system_prompt = [{"role": "system", "content": "Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n"}]

def _encode_header(message, tokenizer):
    tokens = []
-    tokens.extend(tokenizer.encode("<|start_header_id|>"))
-    tokens.extend(tokenizer.encode(message["role"]))
-    tokens.extend(tokenizer.encode("<|end_header_id|>"))
-    tokens.extend(tokenizer.encode("\n\n"))
+    tokens.extend(tokenizer.encode("<|start_header_id|>", add_special_tokens=False))
+    tokens.extend(tokenizer.encode(message["role"], add_special_tokens=False))
+    tokens.extend(tokenizer.encode("<|end_header_id|>", add_special_tokens=False))
+    tokens.extend(tokenizer.encode("\n\n", add_special_tokens=False))
    return tokens


def _encode_message(message, tokenizer):
    tokens = _encode_header(message, tokenizer)
-    tokens.extend(tokenizer.encode(message["content"].strip()))
-    tokens.extend(tokenizer.encode("<|eot_id|>"))
+    tokens.extend(tokenizer.encode(message["content"], add_special_tokens=False))
+    tokens.extend(tokenizer.encode("<|eot_id|>", add_special_tokens=False))
    return tokens


def _format_dialog(dialog, tokenizer):
    tokens = []
-    tokens.extend(tokenizer.encode("<|begin_of_text|>"))
+    tokens.extend(tokenizer.encode("<|begin_of_text|>", add_special_tokens=False))
+    if dialog[0]["role"] == "system":
+        dialog[0]["content"] = default_system_prompt[0]["content"] + dialog[0]["content"]
+    else:
+        dialog = default_system_prompt + dialog
    for msg in dialog:
        tokens.extend(_encode_message(msg, tokenizer))
-    tokens.extend(_encode_header({"role": "assistant", "content": ""}, tokenizer))
    return tokens


def _format_tokens_llama3(dialogs, tokenizer):
    return [_format_dialog(dialog, tokenizer) for dialog in dialogs]


-def _format_tokens_llama2(dialogs, tokenizer):
-    prompt_tokens = []
-    for dialog in dialogs:
-        if dialog[0]["role"] == "system":
-            dialog = [
-                {
-                    "role": dialog[1]["role"],
-                    "content": B_SYS
-                    + dialog[0]["content"]
-                    + E_SYS
-                    + dialog[1]["content"],
-                }
-            ] + dialog[2:]
-        assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
-            [msg["role"] == "assistant" for msg in dialog[1::2]]
-        ), (
-            "model only supports 'system','user' and 'assistant' roles, "
-            "starting with user and alternating (u/a/u/a/u...)"
-        )
-        """
-        Please verify that your tokenizer support adding "[INST]", "[/INST]" to your inputs.
-        Here, we are adding it manually.
-        """
-        dialog_tokens: List[int] = sum(
-            [
-                tokenizer.encode(
-                    f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
-                )
-                + [tokenizer.eos_token_id]
-                for prompt, answer in zip(dialog[::2], dialog[1::2])
-            ],
-            [],
-        )
-        assert (
-            dialog[-1]["role"] == "user"
-        ), f"Last message must be from user, got {dialog[-1]['role']}"
-        dialog_tokens += tokenizer.encode(
-            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
-        )
-        prompt_tokens.append(dialog_tokens)
-    return prompt_tokens
-
-
@pytest.mark.skip_missing_tokenizer
@patch("chat_completion.AutoTokenizer")
@patch("chat_completion.load_model")
def test_chat_completion(
    load_model, tokenizer, setup_tokenizer, llama_tokenizer, llama_version
):
+    if "Llama-2" in llama_version:
+        pytest.skip("skipping test for Llama-2")
+
    from chat_completion import main

    setup_tokenizer(tokenizer)
-    load_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
+    load_model.return_value.get_input_embeddings.return_value.weight.shape = [128256]

    kwargs = {
        "prompt_file": (CHAT_COMPLETION_DIR / "chats.json").as_posix(),
@@ -116,13 +67,8 @@ def test_chat_completion(
    main(llama_version, **kwargs)

    dialogs = read_dialogs_from_file(kwargs["prompt_file"])
-    format_tokens = (
-        _format_tokens_llama2
-        if llama_version == "meta-llama/Llama-2-7b-hf"
-        else _format_tokens_llama3
-    )

-    REF_RESULT = format_tokens(dialogs, llama_tokenizer[llama_version])
+    REF_RESULT = _format_tokens_llama3(dialogs, llama_tokenizer[llama_version])

    assert all(
        (