diff --git a/examples/slackbot/Notebook.ipynb b/examples/slackbot/Notebook.ipynb index 0708159eb..f758f46a7 100644 --- a/examples/slackbot/Notebook.ipynb +++ b/examples/slackbot/Notebook.ipynb @@ -27,109 +27,640 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "61ea2e95-6d9d-4068-ab98-8cf94bc4d9d0", "metadata": {}, "outputs": [], "source": [ "from datetime import datetime, timedelta\n", - "from slack_sdk.socket_mode import SocketModeClient, SocketModeResponse\n", - "import sparrow_pi as kt\n", + "from slack_sdk.socket_mode import SocketModeClient\n", + "from slack_sdk.socket_mode.response import SocketModeResponse\n", + "import sparrow_py as kt\n", + "import pandas\n", "import openai\n", "import getpass\n", "import pyarrow\n", + "import datetime\n", "\n", "# Initialize Kaskada with a local execution context.\n", - "kt.init_session()\n", - "\n", - "# Initialize OpenAI\n", - "openai.api_key = getpass.getpass('OpenAI: API Key')\n", - "\n", - "# Initialize Slack\n", - "slack = SocketModeClient(\n", - " app_token=getpass.getpass('Slack: App Token'),\n", - " web_client=getpass.getpass('Slack: Bot Token'),\n", - ")" + "kt.init_session()" + ] + }, + { + "cell_type": "markdown", + "id": "0035f558-23bd-4b4d-95a0-ed5e8fece673", + "metadata": {}, + "source": [ + "## Fine-tune the model" ] }, { "cell_type": "markdown", - "id": "9b8a144d-8d79-4943-b99b-d3470ee96283", + "id": "6c3c5682-bfe0-44ca-9a5a-52a0da74e5de", "metadata": {}, "source": [ - "## Prompt Engineering" + "### Read Historical Messages" ] }, { "cell_type": "code", - "execution_count": null, - "id": "7e6fedb9", + "execution_count": 3, + "id": "9d224bec-e5a1-4c67-8764-e3dcdbc5e0ac", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
_time_subsort_key_hash_keysubtypetsusertextteamuser_team...reactionsthread_tsreply_countreply_users_countlatest_replyis_lockedsubscribedlast_readparent_user_idchannel
02023-07-25 19:42:13515750806798332339587generalmessage2023-07-25 19:42:13U05JQJJDJ6P<@U05JQJJDJ6P> has joined the channelNoneNone...NoneNoneNaNNaNNoneNoneNoneNoneNonegeneral
12023-07-25 19:42:14143094307063304068259randommessage2023-07-25 19:42:14U05JQJJDJ6P<@U05JQJJDJ6P> has joined the channelNoneNone...NoneNoneNaNNaNNoneNoneNoneNoneNonerandom
22023-07-25 19:44:2702954779196800164886demomessage2023-07-25 19:44:27U05JQJJDJ6P<@U05JQJJDJ6P> has joined the channelNoneNone...NoneNoneNaNNaNNoneNoneNoneNoneNonedemo
32023-07-26 08:29:35615750806798332339587generalmessage2023-07-26 08:29:35U05JQJJDJ6Pold message 1T05JA5XCR9DT05JA5XCR9D...NoneNoneNaNNaNNoneNoneNoneNoneNonegeneral
42023-07-26 08:29:37715750806798332339587generalmessage2023-07-26 08:29:37U05JQJJDJ6Pold message 2T05JA5XCR9DT05JA5XCR9D...NoneNoneNaNNaNNoneNoneNoneNoneNonegeneral
\n", + "

5 rows × 24 columns

\n", + "
" + ], + "text/plain": [ + " _time _subsort _key_hash _key subtype \\\n", + "0 2023-07-25 19:42:13 5 15750806798332339587 general message \n", + "1 2023-07-25 19:42:14 14 3094307063304068259 random message \n", + "2 2023-07-25 19:44:27 0 2954779196800164886 demo message \n", + "3 2023-07-26 08:29:35 6 15750806798332339587 general message \n", + "4 2023-07-26 08:29:37 7 15750806798332339587 general message \n", + "\n", + " ts user text \\\n", + "0 2023-07-25 19:42:13 U05JQJJDJ6P <@U05JQJJDJ6P> has joined the channel \n", + "1 2023-07-25 19:42:14 U05JQJJDJ6P <@U05JQJJDJ6P> has joined the channel \n", + "2 2023-07-25 19:44:27 U05JQJJDJ6P <@U05JQJJDJ6P> has joined the channel \n", + "3 2023-07-26 08:29:35 U05JQJJDJ6P old message 1 \n", + "4 2023-07-26 08:29:37 U05JQJJDJ6P old message 2 \n", + "\n", + " team user_team ... reactions thread_ts reply_count \\\n", + "0 None None ... None None NaN \n", + "1 None None ... None None NaN \n", + "2 None None ... None None NaN \n", + "3 T05JA5XCR9D T05JA5XCR9D ... None None NaN \n", + "4 T05JA5XCR9D T05JA5XCR9D ... None None NaN \n", + "\n", + " reply_users_count latest_reply is_locked subscribed last_read \\\n", + "0 NaN None None None None \n", + "1 NaN None None None None \n", + "2 NaN None None None None \n", + "3 NaN None None None None \n", + "4 NaN None None None None \n", + "\n", + " parent_user_id channel \n", + "0 None general \n", + "1 None random \n", + "2 None demo \n", + "3 None general \n", + "4 None general \n", + "\n", + "[5 rows x 24 columns]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "def build_conversation(messages):\n", - " message_time = messages.col(\"ts\")\n", - " last_message_time = message_time.lag(1) # !!!\n", - " is_new_conversation = message_time.seconds_since(last_message_time) > 10 * 60\n", + "messages = kt.sources.ArrowSource(\n", + " data = pandas.read_parquet(\"./messages.parquet\"), \n", + " time_column_name = \"ts\", \n", + " key_column_name = \"channel\",\n", + ")\n", "\n", - " return messages \\\n", - " .select(\"user\", \"ts\", \"text\", \"reactions\") \\\n", - " .collect(window=kt.windows.Since(is_new_conversation), max=100)" + "messages.preview(5)" ] }, { "cell_type": "markdown", - "id": "9247233a", + "id": "5076d2bf-6830-460b-a9cb-948d8f106edc", "metadata": {}, - "source": [] + "source": [ + "### Build examples" + ] }, { "cell_type": "code", - "execution_count": null, - "id": "fdb2d959-d371-4026-9f8d-4ab26cfbf317", + "execution_count": 6, + "id": "af7d2a45-eb89-47ce-b471-a39ad8c7bbc7", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
_time_subsort_key_hash_keypromptcompletion
02023-07-25 19:43:13015750806798332339587general[{'ts': 1690314133000000000, 'user': 'U05JQJJD...[]
12023-07-25 19:43:1413094307063304068259random[{'ts': 1690314134000000000, 'user': 'U05JQJJD...[]
22023-07-25 19:45:2722954779196800164886demo[{'ts': 1690314267000000000, 'user': 'U05JQJJD...[]
32023-07-26 08:30:35315750806798332339587general[{'ts': 1690314133000000000, 'user': 'U05JQJJD...[]
42023-07-26 08:30:37415750806798332339587general[{'ts': 1690314133000000000, 'user': 'U05JQJJD...[]
52023-07-26 08:31:1052954779196800164886demo[{'ts': 1690314267000000000, 'user': 'U05JQJJD...[]
62023-07-26 08:31:1463094307063304068259random[{'ts': 1690314134000000000, 'user': 'U05JQJJD...[]
72023-07-26 08:31:4073094307063304068259random[{'ts': 1690314134000000000, 'user': 'U05JQJJD...[]
82023-07-26 08:31:4983094307063304068259random[{'ts': 1690314134000000000, 'user': 'U05JQJJD...[]
92023-07-26 08:31:5393094307063304068259random[{'ts': 1690314134000000000, 'user': 'U05JQJJD...[]
102023-07-26 08:31:57103094307063304068259random[{'ts': 1690314134000000000, 'user': 'U05JQJJD...[]
112023-07-26 09:56:07113094307063304068259random[{'ts': 1690314134000000000, 'user': 'U05JQJJD...[]
122023-07-26 09:56:351215750806798332339587general[{'ts': 1690314133000000000, 'user': 'U05JQJJD...[]
132023-07-26 09:56:52132954779196800164886demo[{'ts': 1690314267000000000, 'user': 'U05JQJJD...[]
142023-07-26 19:38:00142954779196800164886demo[{'ts': 1690314267000000000, 'user': 'U05JQJJD...[]
152023-07-26 19:38:001515750806798332339587general[{'ts': 1690314133000000000, 'user': 'U05JQJJD...[]
162023-07-26 19:38:00163094307063304068259random[{'ts': 1690314134000000000, 'user': 'U05JQJJD...[]
172023-07-26 19:44:07173094307063304068259random[{'ts': 1690314134000000000, 'user': 'U05JQJJD...[U05JQJJDJ6P]
182023-07-26 20:17:30182954779196800164886demo[{'ts': 1690314267000000000, 'user': 'U05JQJJD...[]
192023-07-26 20:17:301915750806798332339587general[{'ts': 1690314133000000000, 'user': 'U05JQJJD...[]
202023-07-26 20:17:30203094307063304068259random[{'ts': 1690314134000000000, 'user': 'U05JQJJD...[U05JQJJDJ6P]
212023-07-26 20:19:252115750806798332339587general[{'ts': 1690314133000000000, 'user': 'U05JQJJD...[U05JQJJDJ6P]
222023-07-26 20:32:04015750806798332339587general[{'ts': 1690314133000000000, 'user': 'U05JQJJD...[U05JQJJDJ6P]
232023-07-26 20:32:19115750806798332339587general[{'ts': 1690314133000000000, 'user': 'U05JQJJD...[U05JQJJDJ6P]
\n", + "
" + ], + "text/plain": [ + " _time _subsort _key_hash _key \\\n", + "0 2023-07-25 19:43:13 0 15750806798332339587 general \n", + "1 2023-07-25 19:43:14 1 3094307063304068259 random \n", + "2 2023-07-25 19:45:27 2 2954779196800164886 demo \n", + "3 2023-07-26 08:30:35 3 15750806798332339587 general \n", + "4 2023-07-26 08:30:37 4 15750806798332339587 general \n", + "5 2023-07-26 08:31:10 5 2954779196800164886 demo \n", + "6 2023-07-26 08:31:14 6 3094307063304068259 random \n", + "7 2023-07-26 08:31:40 7 3094307063304068259 random \n", + "8 2023-07-26 08:31:49 8 3094307063304068259 random \n", + "9 2023-07-26 08:31:53 9 3094307063304068259 random \n", + "10 2023-07-26 08:31:57 10 3094307063304068259 random \n", + "11 2023-07-26 09:56:07 11 3094307063304068259 random \n", + "12 2023-07-26 09:56:35 12 15750806798332339587 general \n", + "13 2023-07-26 09:56:52 13 2954779196800164886 demo \n", + "14 2023-07-26 19:38:00 14 2954779196800164886 demo \n", + "15 2023-07-26 19:38:00 15 15750806798332339587 general \n", + "16 2023-07-26 19:38:00 16 3094307063304068259 random \n", + "17 2023-07-26 19:44:07 17 3094307063304068259 random \n", + "18 2023-07-26 20:17:30 18 2954779196800164886 demo \n", + "19 2023-07-26 20:17:30 19 15750806798332339587 general \n", + "20 2023-07-26 20:17:30 20 3094307063304068259 random \n", + "21 2023-07-26 20:19:25 21 15750806798332339587 general \n", + "22 2023-07-26 20:32:04 0 15750806798332339587 general \n", + "23 2023-07-26 20:32:19 1 15750806798332339587 general \n", + "\n", + " prompt completion \n", + "0 [{'ts': 1690314133000000000, 'user': 'U05JQJJD... [] \n", + "1 [{'ts': 1690314134000000000, 'user': 'U05JQJJD... [] \n", + "2 [{'ts': 1690314267000000000, 'user': 'U05JQJJD... [] \n", + "3 [{'ts': 1690314133000000000, 'user': 'U05JQJJD... [] \n", + "4 [{'ts': 1690314133000000000, 'user': 'U05JQJJD... [] \n", + "5 [{'ts': 1690314267000000000, 'user': 'U05JQJJD... [] \n", + "6 [{'ts': 1690314134000000000, 'user': 'U05JQJJD... [] \n", + "7 [{'ts': 1690314134000000000, 'user': 'U05JQJJD... [] \n", + "8 [{'ts': 1690314134000000000, 'user': 'U05JQJJD... [] \n", + "9 [{'ts': 1690314134000000000, 'user': 'U05JQJJD... [] \n", + "10 [{'ts': 1690314134000000000, 'user': 'U05JQJJD... [] \n", + "11 [{'ts': 1690314134000000000, 'user': 'U05JQJJD... [] \n", + "12 [{'ts': 1690314133000000000, 'user': 'U05JQJJD... [] \n", + "13 [{'ts': 1690314267000000000, 'user': 'U05JQJJD... [] \n", + "14 [{'ts': 1690314267000000000, 'user': 'U05JQJJD... [] \n", + "15 [{'ts': 1690314133000000000, 'user': 'U05JQJJD... [] \n", + "16 [{'ts': 1690314134000000000, 'user': 'U05JQJJD... [] \n", + "17 [{'ts': 1690314134000000000, 'user': 'U05JQJJD... [U05JQJJDJ6P] \n", + "18 [{'ts': 1690314267000000000, 'user': 'U05JQJJD... [] \n", + "19 [{'ts': 1690314133000000000, 'user': 'U05JQJJD... [] \n", + "20 [{'ts': 1690314134000000000, 'user': 'U05JQJJD... [U05JQJJDJ6P] \n", + "21 [{'ts': 1690314133000000000, 'user': 'U05JQJJD... [U05JQJJDJ6P] \n", + "22 [{'ts': 1690314133000000000, 'user': 'U05JQJJD... [U05JQJJDJ6P] \n", + "23 [{'ts': 1690314133000000000, 'user': 'U05JQJJD... [U05JQJJDJ6P] " + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "def build_examples(messages):\n", - " duration = datetime.timedelta(minutes=5)\n", + "# Group messages by thread (if present) or channel\n", + "#messages = messages.with_key(kt.record({\n", + "# \"channel\": messages.col(\"channel\"),\n", + "# \"thread\": messages.col(\"thread_ts\"),\n", + "# }))\n", "\n", - " coverstation = build_conversation(messages)\n", - " shifted_coversation = coverstation.shift_by(duration) # !!!\n", "\n", - " reaction_users = coverstation.col(\"reactions\").col(\"name\").collect(kt.windows.Trailing(duration)).flatten() # !!!\n", - " participating_users = coverstation.col(\"user\").collect(kt.windows.Trailing(duration)) # !!!\n", - " engaged_users = kt.union(reaction_users, participating_users) # !!!\n", + "# Build the input prompt from recent messages\n", + "prompts = messages \\\n", + " .select(\"user\", \"ts\", \"text\", \"reactions\") \\\n", + " .collect(max=20)\n", "\n", - " return kt.record({ \"prompt\": shifted_coversation, \"completion\": engaged_users}) \\\n", - " .filter(shifted_coversation.is_not_null())" + "\n", + "# Build the completion from users who engage after the prompt\n", + "duration = datetime.timedelta(minutes=1)\n", + "\n", + "shifted_prompts = prompts.shift_by(duration)\n", + "\n", + "reaction_users = messages.collect(max=100).col(\"reactions\").flatten().col(\"users\").flatten().last()\n", + "#reaction_users = messages.collect(kt.Trailing(duration), max=100).col(\"reactions\").flatten().col(\"users\").flatten().last()\n", + "#participating_users = prompts.col(\"user\").collect(max=100) #kt.windows.Trailing(duration))\n", + "engaged_users = reaction_users #kt.union(reaction_users, participating_users)\n", + "\n", + "examples = kt.record({\"prompt\": shifted_prompts, \"completion\": engaged_users}) \\\n", + " .filter(shifted_prompts.is_not_null())\n", + "\n", + "\n", + "prompts.preview(5)\n", + "examples.preview(100)" ] }, { "cell_type": "markdown", - "id": "0035f558-23bd-4b4d-95a0-ed5e8fece673", + "id": "fed3b924-e7de-414b-bba5-b119e40921f0", "metadata": {}, "source": [ - "## Fine-tune the model" + "## Fine-tune a model" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "af7d2a45-eb89-47ce-b471-a39ad8c7bbc7", + "cell_type": "markdown", + "id": "e78fa9bd-9c40-403d-a7ee-a15620a88418", "metadata": {}, - "outputs": [], "source": [ - "import pandas\n", - "import sparrow_pi.sources as sources\n", - "\n", - "messages = kt.sources.Parquet(\"./messages.parquet\", time = \"ts\", entity = \"channel\")\n", - "messages = messages.with_key(kt.record({ # !!!\n", - " \"channel\": messages.col(\"channel\"),\n", - " \"thread\": messages.col(\"thread_ts\"),\n", - " }))\n", - "examples = build_examples(messages)\n", - "\n", - "examples_df = examples.run().to_pandas()" + "### Create training dataset" ] }, { @@ -141,22 +672,36 @@ "source": [ "from sklearn import preprocessing\n", "\n", + "# Extract examples from historical data\n", + "examples_df = examples.run().to_pandas().drop([\"_time\", \"_subsort\", \"_key_hash\", \"_key\"], axis=1)\n", + "\n", + "\n", + "# Encode user ID labels\n", "le = preprocessing.LabelEncoder()\n", "le.fit(examples_df.completion.explode())\n", "\n", + "\n", "# Format for the OpenAI API\n", "def format_prompt(prompt):\n", " return \"start -> \" + \"\\n\\n\".join([f' {msg[\"user\"]} --> {msg[\"text\"]} ' for msg in prompt]) + \"\\n\\n###\\n\\n\"\n", "examples_df.prompt = examples_df.prompt.apply(format_prompt)\n", - "\n", "def format_completion(completion):\n", - " return \" \" + (\" \".join([le.transform(u) for u in completion]) if len(completion) > 0 else \"nil\") + \" end\"\n", + " return \" \" + (\" \".join(le.transform(completion).astype(str)) if len(completion) > 0 else \"nil\") + \" end\"\n", "examples_df.completion = examples_df.completion.apply(format_completion)\n", "\n", + "\n", "# Write examples to file\n", "examples_df.to_json(\"examples.jsonl\", orient='records', lines=True)" ] }, + { + "cell_type": "markdown", + "id": "cfc60311-5ca1-49f3-8e35-4070174e0258", + "metadata": {}, + "source": [ + "### Upload to OpenAI" + ] + }, { "cell_type": "code", "execution_count": null, @@ -169,12 +714,23 @@ "from types import SimpleNamespace\n", "from openai import cli\n", "\n", - "# verifiy data format, split for training & validation\n", + "# Initialize OpenAI\n", + "openai.api_key = getpass.getpass('OpenAI: API Key')\n", + "\n", + "# Verifiy data format, split for training & validation, upload to OpenAI\n", "args = SimpleNamespace(file='./examples.jsonl', quiet=True)\n", "cli.FineTune.prepare_data(args)\n", "training_id = cli.FineTune._get_or_upload('./examples_prepared_train.jsonl', True)" ] }, + { + "cell_type": "markdown", + "id": "f0808e6e-8239-4487-b22b-14b5a00948c6", + "metadata": {}, + "source": [ + "### Create the training job" + ] + }, { "cell_type": "code", "execution_count": null, @@ -193,103 +749,6 @@ ")\n", "print(f'Fine-tuning model with job ID: \"{resp[\"id\"]}\"')" ] - }, - { - "cell_type": "markdown", - "id": "b3e29109-cc00-4bf5-ba23-069e8db1f179", - "metadata": { - "jp-MarkdownHeadingCollapsed": true, - "tags": [] - }, - "source": [ - "## Notify users of conversations they need to know about" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "540afff7-4ebc-427f-8205-1ed145e0c59a", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "import json, math\n", - "\n", - "min_prob_for_response = 0.50\n", - "\n", - "# Receive Slack messages in real-time\n", - "live_messages = kt.sources.ArrowSource(entity_column=\"channel\", time_column=\"ts\")\n", - "\n", - "# Receive messages from Slack\n", - "def handle_message(client, req):\n", - " # Acknowledge the message back to Slack\n", - " client.send_socket_mode_response(SocketModeResponse(envelope_id=req.envelope_id))\n", - " \n", - " if req.type == \"events_api\" and \"event\" in req.payload:\n", - " e = req.payload[\"event\"]\n", - " if \"previous_message\" in e or e[\"type\"] == \"reaction_added\":\n", - " return\n", - " # send message events to Kaskada\n", - " live_messages.add_event(pyarrow.json.read_json(e))\n", - "\n", - "slack.socket_mode_request_listeners.append(handle_message)\n", - "slack.connect()\n", - "\n", - "# Handle messages in realtime\n", - "# A \"conversation\" is a list of messages\n", - "for conversation in build_conversation(live_messages).start().to_generator():\n", - " if len(conversation) == 0:\n", - " continue\n", - " \n", - " # Ask the model who should be notified\n", - " res = openai.Completion.create(\n", - " model=\"davinci:ft-personal:coversation-users-full-kaskada-2023-08-05-14-25-30\", \n", - " prompt=format_prompt(conversation),\n", - " logprobs=5,\n", - " max_tokens=1,\n", - " stop=\" end\",\n", - " temperature=0,\n", - " )\n", - "\n", - " users = []\n", - " logprobs = res[\"choices\"][0][\"logprobs\"][\"top_logprobs\"][0]\n", - " for user in logprobs:\n", - " if math.exp(logprobs[user]) > min_prob_for_response:\n", - " user = users.strip()\n", - " # if users include `nil`, stop processing\n", - " if user == \"nil\":\n", - " users = []\n", - " break\n", - " users.append(user)\n", - "\n", - " # alert on most recent message in conversation\n", - " msg = conversation.pop()\n", - " \n", - " # Send notification to users\n", - " for user in users:\n", - " user_id = le.inverse_transform(user)\n", - "\n", - " # get user channel for slackbot\n", - " app = slack.web_client.users_conversations(\n", - " types=\"im\",\n", - " user=user_id,\n", - " )\n", - " \n", - " # confirm user has slackbot installed\n", - " if len(app[\"channels\"]) == 0:\n", - " continue\n", - "\n", - " link = slack.web_client.chat_getPermalink(\n", - " channel=msg[\"channel\"],\n", - " message_ts=msg[\"ts\"],\n", - " )[\"permalink\"]\n", - " \n", - " slack.web_client.chat_postMessage(\n", - " channel=app[\"channels\"][0][\"id\"],\n", - " text=f'You may be interested in this converstation: <{link}|{msg[\"text\"]}>'\n", - " )" - ] } ], "metadata": { @@ -308,7 +767,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.11.3" } }, "nbformat": 4, diff --git a/examples/slackbot/README.md b/examples/slackbot/README.md index ee26b707b..08c68b778 100644 --- a/examples/slackbot/README.md +++ b/examples/slackbot/README.md @@ -37,15 +37,5 @@ copy( union_by_name=true) ) ) to 'messages.parquet' (FORMAT PARQUET) -; - - select - from ( - select * from read_json_auto('data/iloveai-initial-export/*/*.json', - format='array', - filename=true, - union_by_name=true) - ) - limit 10 ; ``` \ No newline at end of file diff --git a/examples/slackbot/messages.parquet b/examples/slackbot/messages.parquet index a3ad8a5ac..170705f99 100644 Binary files a/examples/slackbot/messages.parquet and b/examples/slackbot/messages.parquet differ diff --git a/examples/slackbot/run.py b/examples/slackbot/run.py new file mode 100644 index 000000000..a2743cc21 --- /dev/null +++ b/examples/slackbot/run.py @@ -0,0 +1,128 @@ +import json, math, openai, os +from slack_sdk.web import WebClient +from slack_sdk.socket_mode import SocketModeClient +from slack_sdk.socket_mode.response import SocketModeResponse + +output_map = {} + +with open('./user_output_map.json', 'r') as file: + output_map = json.load(file) + +print(f'Loaded output map: {output_map}') + + +# Initialize OpenAI +openai.api_key = os.environ.get("OPEN_AI_KEY") + +# Initialize Slack +slack = SocketModeClient( + app_token=os.environ.get("SLACK_APP_TOKEN"), + web_client=WebClient(token=os.environ.get("SLACK_BOT_TOKEN")) +) + +min_prob_for_response = 0.50 + +# Format for the OpenAI API +def format_prompt(prompt): + return "start -> " + "\n\n".join([f' {msg["user"]} --> {msg["text"]} ' for msg in prompt]) + "\n\n###\n\n" + +def handle_conversation(conversation): + if len(conversation) == 0: + return + + print(f'Starting prediction on conversation: {conversation[0]["text"]}') + + # Ask the model who should be notified + res = openai.Completion.create( + model="davinci:ft-personal:coversation-users-full-kaskada-2023-08-05-14-25-30", + prompt=format_prompt(conversation), + max_tokens=1, + stop=" end", + temperature=0, + logprobs=5, + ) + + users = [] + logprobs = res["choices"][0]["logprobs"]["top_logprobs"][0] + + print(f'Recieved log probs: {logprobs}') + for user in logprobs: + if math.exp(logprobs[user]) > min_prob_for_response: + # if `nil` user is an option, stop processing + user = user.strip() + if user == "nil": + users = [] + print('Found nil, stopping.') + break + users.append(user) + + print(f'Found users to alert: {users}') + # alert on most recent message in conversation + msg = conversation.pop() + + # Send notification to users + for user_num in users: + if user_num not in output_map: + print(f'User: {user_num} not in output_map, stopping.') + else: + user_id = output_map[user_num] + + print(f'Found user {user_num} in output map: {user_id}') + + link = slack.web_client.chat_getPermalink( + channel=msg["channel"], + message_ts=msg["ts"], + )["permalink"] + + print(f'Got message link: {link}') + + res = slack.web_client.users_conversations( + types="im", + user=user_id, + ) + if len(res["channels"]) == 0: + print(f'User: {user} hasn\'t installed the slackbot yet') + else: + app_channel = res["channels"][0]["id"] + print(f'Got user\'s bot channel id: {app_channel}') + + slack.web_client.chat_postMessage( + channel=app_channel, + text=f'You may be interested in this converstation: <{link}|{msg["text"]}>' + ) + + print(f'Posted alert message') + +# Receive Slack messages in real-time +#live_messages = kt.sources.read_stream(entity_column="channel", time_column="ts") + +# Receive messages from Slack +def handle_message(client, req): + # Acknowledge the message back to Slack + client.send_socket_mode_response(SocketModeResponse(envelope_id=req.envelope_id)) + + # Deliver the message to Kaskada + #live_messages.add_event(pyarrow.json.read_json(req.payload)) + + if req.type == "events_api" and "event" in req.payload: + e = req.payload["event"] + + # ignore message edit, delete, reaction events + if "previous_message" in e or e["type"] == "reaction_added": + return + + # make single-message conversations for now + handle_conversation([e]) + + +# Handle messages in realtime +# A "conversation" is a list of messages +#for conversation in build_conversation(live_messages).start().to_generator(): + + +slack.socket_mode_request_listeners.append(handle_message) +slack.connect() + +# Just not to stop this process +from threading import Event +Event().wait() diff --git a/examples/slackbot/slackbot.py b/examples/slackbot/slackbot.py index 094b6d68d..ba286c6f2 100644 --- a/examples/slackbot/slackbot.py +++ b/examples/slackbot/slackbot.py @@ -1,123 +1,100 @@ -import json, math, openai, os, pyarrow -from slack_sdk.web import WebClient -from slack_sdk.socket_mode import SocketModeClient +import json, math, datetime, openai, os, pyarrow, pandas, asyncio +from slack_sdk.web.async_client import AsyncWebClient +from slack_sdk.socket_mode.aiohttp import SocketModeClient from slack_sdk.socket_mode.response import SocketModeResponse -import sparrow_pi as kt +import sparrow_py as kt -def build_conversation(messages): - message_time = messages.col("ts") - last_message_time = message_time.lag(1) # !!! - is_new_conversation = message_time.seconds_since(last_message_time) > 10 * 60 - - return messages \ - .select("user", "ts", "text", "reactions") \ - .collect(window=kt.windows.Since(is_new_conversation), max=100) - -def build_examples(messages): - duration = kt.minutes(5) # !!! - - coverstation = build_conversation(messages) - shifted_coversation = coverstation.shift_by(duration) # !!! - - reaction_users = coverstation.col("reactions").col("name").collect(kt.windows.Trailing(duration)).flatten() # !!! - participating_users = coverstation.col("user").collect(kt.windows.Trailing(duration)) # !!! - engaged_users = kt.union(reaction_users, participating_users) # !!! - - return kt.record({ "prompt": shifted_coversation, "completion": engaged_users}) \ - .filter(shifted_coversation.is_not_null()) - -def format_prompt(prompt): - return "start -> " + "\n\n".join([f' {msg["user"]} --> {msg["text"]} ' for msg in prompt]) + "\n\n###\n\n" - -def main(): +async def main(): + start = datetime.datetime.now() + + # Load user label map output_map = {} - with open('./user_output_map.json', 'r') as file: output_map = json.load(file) - print(f'Loaded output map: {output_map}') - - # Initialize Kaskada with a local execution context. + # Initialize clients kt.init_session() - - # Initialize OpenAI openai.api_key = os.environ.get("OPEN_AI_KEY") - - # Initialize Slack slack = SocketModeClient( app_token=os.environ.get("SLACK_APP_TOKEN"), - web_client=WebClient(token=os.environ.get("SLACK_BOT_TOKEN")) + web_client=AsyncWebClient(token=os.environ.get("SLACK_BOT_TOKEN")) ) - min_prob_for_response = 0.50 - # Receive Slack messages in real-time - live_messages = kt.sources.read_stream(entity_column="channel", time_column="ts") + + # Backfill state with historical data + historical_data = pandas.read_parquet("./messages.parquet")[:1] + schema = pyarrow.Schema.from_pandas(historical_data) + messages = kt.sources.ArrowSource( + data = historical_data, + time_column_name = "ts", + key_column_name = "channel", + ) + - # Receive messages from Slack - def handle_message(client, req): + + # Receive Slack messages in real-time + async def handle_message(client, req): # Acknowledge the message back to Slack - client.send_socket_mode_response(SocketModeResponse(envelope_id=req.envelope_id)) + await client.send_socket_mode_response(SocketModeResponse(envelope_id=req.envelope_id)) if req.type == "events_api" and "event" in req.payload: - e = req.payload["event"] - - print(f'Received event from slack websocket: {e}') - # ignore message edit, delete, reaction events - if "previous_message" in e or e["type"] == "reaction_added": + if "previous_message" in req.payload["event"] or req.payload["event"]["type"] == "reaction_added": return - print(f'Sending message event to kaskada: {e}') - - # Deliver the message to Kaskada - live_messages.add_event(pyarrow.json.read_json(e)) + req.payload["event"]["ts"] = datetime.datetime.fromtimestamp(float(req.payload["event"]["ts"])) + del req.payload["event"]["team"] + data = pyarrow.RecordBatch.from_pylist([req.payload["event"]], schema=schema) + messages.add_data(data) slack.socket_mode_request_listeners.append(handle_message) - slack.connect() + await slack.connect() + + + + # Compute conversations from individual messages + conversations = messages.with_key(kt.record({ + "channel": messages.col("channel"), + "thread": messages.col("thread_ts"), + })) \ + .select("user", "ts", "text", "reactions") \ + .collect(max=3) + - # Handle messages in realtime - # A "conversation" is a list of messages - for conversation in build_conversation(live_messages).start().to_generator(): - if len(conversation) == 0: + + # Handle each conversation as it occurs + async for row in conversations.run(materialize=True).iter_rows_async(): + conversation = row[" result"] + if len(conversation) == 0 or row["_time"] < start: continue print(f'Starting completion on conversation with first message text: {conversation[0]["text"]}') - prompt = format_prompt(conversation) - - print(f'Using prompt: {prompt}') - # Ask the model who should be notified res = openai.Completion.create( model="davinci:ft-personal:coversation-users-full-kaskada-2023-08-05-14-25-30", - prompt=prompt, + prompt="start -> " + "\n\n".join([f' {msg["user"]} --> {msg["text"]} ' for msg in conversation]) + "\n\n###\n\n", logprobs=5, max_tokens=1, stop=" end", - temperature=0, + temperature=1, ) - print(f'Received completion response: {res}') - + msg = conversation.pop(0) users = [] logprobs = res["choices"][0]["logprobs"]["top_logprobs"][0] - - print(f'Found logprobs: {logprobs}') + print(f"Predicted interest logprobs: {logprobs}") + print(f"Notifying users: {users}") for user in logprobs: - if math.exp(logprobs[user]) > min_prob_for_response: - user = users.strip() + if math.exp(logprobs[user]) > 0.30: + user = user.strip() # if users include `nil`, stop processing if user == "nil": users = [] break users.append(user) - - print(f'Found users to alert: {users}') - - # alert on most recent message in conversation - msg = conversation.pop() - + # Send notification to users for user_num in users: if user_num not in output_map: @@ -153,5 +130,7 @@ def handle_message(client, req): print(f'Posted alert message') + print("Done") + if __name__ == "__main__": - main() \ No newline at end of file + asyncio.run(main()) \ No newline at end of file