|
38 | 38 | },
|
39 | 39 | "id": "49abde692940b09e"
|
40 | 40 | },
|
| 41 | + { |
| 42 | + "cell_type": "markdown", |
| 43 | + "source": [ |
| 44 | + "## Single-round chat completion" |
| 45 | + ], |
| 46 | + "metadata": { |
| 47 | + "collapsed": false |
| 48 | + }, |
| 49 | + "id": "84b663418e3e3b19" |
| 50 | + }, |
41 | 51 | {
|
42 | 52 | "cell_type": "code",
|
43 | 53 | "execution_count": null,
|
44 | 54 | "outputs": [],
|
45 | 55 | "source": [
|
46 | | - "# normal \n", |
47 | 56 | "chat_completion_result = taskingai.inference.chat_completion(\n",
|
48 | 57 | " model_id=model_id,\n",
|
49 | 58 | " messages=[\n",
|
|
58 | 67 | },
|
59 | 68 | "id": "43dcc632665f0de4"
|
60 | 69 | },
|
| 70 | + { |
| 71 | + "cell_type": "markdown", |
| 72 | + "source": [ |
| 73 | + "## Multi-round chat completion" |
| 74 | + ], |
| 75 | + "metadata": { |
| 76 | + "collapsed": false |
| 77 | + }, |
| 78 | + "id": "9f84e86d19409580" |
| 79 | + }, |
61 | 80 | {
|
62 | 81 | "cell_type": "code",
|
63 | 82 | "execution_count": null,
|
64 | 83 | "outputs": [],
|
65 | 84 | "source": [
|
66 | | - "# multi round chat completion\n", |
67 | 85 | "chat_completion_result = taskingai.inference.chat_completion(\n",
|
68 | 86 | " model_id=model_id,\n",
|
69 | 87 | " messages=[\n",
|
|
87 | 105 | "execution_count": null,
|
88 | 106 | "outputs": [],
|
89 | 107 | "source": [
|
90 | | - "# config max tokens\n", |
| 108 | + "# Add max tokens configs\n", |
91 | 109 | "chat_completion_result = taskingai.inference.chat_completion(\n",
|
92 | 110 | " model_id=model_id,\n",
|
93 | 111 | " messages=[\n",
|
|
109 | 127 | },
|
110 | 128 | "id": "f7c1b8be2579d9e0"
|
111 | 129 | },
|
| 130 | + { |
| 131 | + "cell_type": "markdown", |
| 132 | + "source": [ |
| 133 | + "## Function call" |
| 134 | + ], |
| 135 | + "metadata": { |
| 136 | + "collapsed": false |
| 137 | + }, |
| 138 | + "id": "c615ece16c777029" |
| 139 | + }, |
112 | 140 | {
|
113 | 141 | "cell_type": "code",
|
114 | 142 | "execution_count": null,
|
115 | 143 | "outputs": [],
|
116 | 144 | "source": [
|
117 | | - "# function call\n", |
| 145 | + "# function definition\n", |
118 | 146 | "function = Function(\n",
|
119 | 147 | " name=\"plus_a_and_b\",\n",
|
120 | 148 | " description=\"Sum up a and b and return the result\",\n",
|
|
132 | 160 | " },\n",
|
133 | 161 | " \"required\": [\"a\", \"b\"]\n",
|
134 | 162 | " },\n",
|
135 | | - ")\n", |
| 163 | + ")" |
| 164 | + ], |
| 165 | + "metadata": { |
| 166 | + "collapsed": false |
| 167 | + }, |
| 168 | + "id": "2645bdc3df011e7d" |
| 169 | + }, |
| 170 | + { |
| 171 | + "cell_type": "code", |
| 172 | + "execution_count": null, |
| 173 | + "outputs": [], |
| 174 | + "source": [ |
| 175 | + "# chat completion with the function call\n", |
136 | 176 | "chat_completion_result = taskingai.inference.chat_completion(\n",
|
137 | 177 | " model_id=model_id,\n",
|
138 | 178 | " messages=[\n",
|
139 | | - " SystemMessage(\"You are a professional assistant.\"),\n", |
140 | 179 | " UserMessage(\"What is the result of 112 plus 22?\"),\n",
|
141 | 180 | " ],\n",
|
142 | 181 | " functions=[function]\n",
|
143 | 182 | ")\n",
|
144 | | - "print(f\"chat_completion_result = {chat_completion_result}\")\n", |
145 | | - "\n", |
146 | | - "assistant_function_call_message = chat_completion_result.message\n", |
147 | | - "fucntion_name = assistant_function_call_message.function_call.name\n", |
148 | | - "argument_content = json.dumps(assistant_function_call_message.function_call.arguments)\n", |
149 | | - "print(f\"function name: {fucntion_name}, argument content: {argument_content}\")" |
| 183 | + "function_call_assistant_message = chat_completion_result.message\n", |
| 184 | + "print(f\"function_call_assistant_message = {function_call_assistant_message}\")" |
150 | 185 | ],
|
151 | 186 | "metadata": {
|
152 | 187 | "collapsed": false
|
153 | 188 | },
|
154 | | - "id": "2645bdc3df011e7d" |
| 189 | + "id": "850adc819aa228fc" |
155 | 190 | },
|
156 | 191 | {
|
157 | | - "cell_type": "markdown", |
158 | | - "source": [], |
| 192 | + "cell_type": "code", |
| 193 | + "execution_count": null, |
| 194 | + "outputs": [], |
| 195 | + "source": [ |
| 196 | + "# get the function call result\n", |
| 197 | + "def plus_a_and_b(a, b):\n", |
| 198 | + " return a + b\n", |
| 199 | + "\n", |
| 200 | + "arguments = function_call_assistant_message.function_call.arguments\n", |
| 201 | + "function_call_result = plus_a_and_b(**arguments)\n", |
| 202 | + "print(f\"function_call_result = {function_call_result}\")" |
| 203 | + ], |
159 | 204 | "metadata": {
|
160 | 205 | "collapsed": false
|
161 | 206 | },
|
162 | | - "id": "ed6957f0c380ba9f" |
| 207 | + "id": "45787662d2148352" |
163 | 208 | },
|
164 | 209 | {
|
165 | 210 | "cell_type": "code",
|
166 | 211 | "execution_count": null,
|
167 | 212 | "outputs": [],
|
168 | 213 | "source": [
|
169 |
| - "# add function message\n", |
| 214 | + "# chat completion with the function result\n", |
170 | 215 | "chat_completion_result = taskingai.inference.chat_completion(\n",
|
171 | 216 | " model_id=model_id,\n",
|
172 | 217 | " messages=[\n",
|
173 | | - " SystemMessage(\"You are a professional assistant.\"),\n", |
174 | 218 | " UserMessage(\"What is the result of 112 plus 22?\"),\n",
|
175 | | - " assistant_function_call_message,\n", |
176 | | - " FunctionMessage(name=fucntion_name, content=\"144\")\n", |
| 219 | + " function_call_assistant_message,\n", |
| 220 | + " FunctionMessage(name=\"plus_a_and_b\", content=str(function_call_result))\n", |
177 | 221 | " ],\n",
|
178 | 222 | " functions=[function]\n",
|
179 | 223 | ")\n",
|
|
184 | 228 | },
|
185 | 229 | "id": "9df9a8b9eafa17d9"
|
186 | 230 | },
|
| 231 | + { |
| 232 | + "cell_type": "markdown", |
| 233 | + "source": [ |
| 234 | + "## Stream mode" |
| 235 | + ], |
| 236 | + "metadata": { |
| 237 | + "collapsed": false |
| 238 | + }, |
| 239 | + "id": "a64da98251c5d3c5" |
| 240 | + }, |
187 | 241 | {
|
188 | 242 | "cell_type": "code",
|
189 | 243 | "execution_count": null,
|
|