diff --git a/examples/slackbot/Notebook.ipynb b/examples/slackbot/Notebook.ipynb
index 0708159eb..f758f46a7 100644
--- a/examples/slackbot/Notebook.ipynb
+++ b/examples/slackbot/Notebook.ipynb
@@ -27,109 +27,640 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"id": "61ea2e95-6d9d-4068-ab98-8cf94bc4d9d0",
"metadata": {},
"outputs": [],
"source": [
"from datetime import datetime, timedelta\n",
- "from slack_sdk.socket_mode import SocketModeClient, SocketModeResponse\n",
- "import sparrow_pi as kt\n",
+ "from slack_sdk.socket_mode import SocketModeClient\n",
+ "from slack_sdk.socket_mode.response import SocketModeResponse\n",
+ "import sparrow_py as kt\n",
+ "import pandas\n",
"import openai\n",
"import getpass\n",
"import pyarrow\n",
+ "import datetime\n",
"\n",
"# Initialize Kaskada with a local execution context.\n",
- "kt.init_session()\n",
- "\n",
- "# Initialize OpenAI\n",
- "openai.api_key = getpass.getpass('OpenAI: API Key')\n",
- "\n",
- "# Initialize Slack\n",
- "slack = SocketModeClient(\n",
- " app_token=getpass.getpass('Slack: App Token'),\n",
- " web_client=getpass.getpass('Slack: Bot Token'),\n",
- ")"
+ "kt.init_session()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0035f558-23bd-4b4d-95a0-ed5e8fece673",
+ "metadata": {},
+ "source": [
+ "## Fine-tune the model"
]
},
{
"cell_type": "markdown",
- "id": "9b8a144d-8d79-4943-b99b-d3470ee96283",
+ "id": "6c3c5682-bfe0-44ca-9a5a-52a0da74e5de",
"metadata": {},
"source": [
- "## Prompt Engineering"
+ "### Read Historical Messages"
]
},
{
"cell_type": "code",
- "execution_count": null,
- "id": "7e6fedb9",
+ "execution_count": 3,
+ "id": "9d224bec-e5a1-4c67-8764-e3dcdbc5e0ac",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " _time | \n",
+ " _subsort | \n",
+ " _key_hash | \n",
+ " _key | \n",
+ " subtype | \n",
+ " ts | \n",
+ " user | \n",
+ " text | \n",
+ " team | \n",
+ " user_team | \n",
+ " ... | \n",
+ " reactions | \n",
+ " thread_ts | \n",
+ " reply_count | \n",
+ " reply_users_count | \n",
+ " latest_reply | \n",
+ " is_locked | \n",
+ " subscribed | \n",
+ " last_read | \n",
+ " parent_user_id | \n",
+ " channel | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 2023-07-25 19:42:13 | \n",
+ " 5 | \n",
+ " 15750806798332339587 | \n",
+ " general | \n",
+ " message | \n",
+ " 2023-07-25 19:42:13 | \n",
+ " U05JQJJDJ6P | \n",
+ " <@U05JQJJDJ6P> has joined the channel | \n",
+ " None | \n",
+ " None | \n",
+ " ... | \n",
+ " None | \n",
+ " None | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " None | \n",
+ " None | \n",
+ " None | \n",
+ " None | \n",
+ " None | \n",
+ " general | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 2023-07-25 19:42:14 | \n",
+ " 14 | \n",
+ " 3094307063304068259 | \n",
+ " random | \n",
+ " message | \n",
+ " 2023-07-25 19:42:14 | \n",
+ " U05JQJJDJ6P | \n",
+ " <@U05JQJJDJ6P> has joined the channel | \n",
+ " None | \n",
+ " None | \n",
+ " ... | \n",
+ " None | \n",
+ " None | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " None | \n",
+ " None | \n",
+ " None | \n",
+ " None | \n",
+ " None | \n",
+ " random | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 2023-07-25 19:44:27 | \n",
+ " 0 | \n",
+ " 2954779196800164886 | \n",
+ " demo | \n",
+ " message | \n",
+ " 2023-07-25 19:44:27 | \n",
+ " U05JQJJDJ6P | \n",
+ " <@U05JQJJDJ6P> has joined the channel | \n",
+ " None | \n",
+ " None | \n",
+ " ... | \n",
+ " None | \n",
+ " None | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " None | \n",
+ " None | \n",
+ " None | \n",
+ " None | \n",
+ " None | \n",
+ " demo | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 2023-07-26 08:29:35 | \n",
+ " 6 | \n",
+ " 15750806798332339587 | \n",
+ " general | \n",
+ " message | \n",
+ " 2023-07-26 08:29:35 | \n",
+ " U05JQJJDJ6P | \n",
+ " old message 1 | \n",
+ " T05JA5XCR9D | \n",
+ " T05JA5XCR9D | \n",
+ " ... | \n",
+ " None | \n",
+ " None | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " None | \n",
+ " None | \n",
+ " None | \n",
+ " None | \n",
+ " None | \n",
+ " general | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 2023-07-26 08:29:37 | \n",
+ " 7 | \n",
+ " 15750806798332339587 | \n",
+ " general | \n",
+ " message | \n",
+ " 2023-07-26 08:29:37 | \n",
+ " U05JQJJDJ6P | \n",
+ " old message 2 | \n",
+ " T05JA5XCR9D | \n",
+ " T05JA5XCR9D | \n",
+ " ... | \n",
+ " None | \n",
+ " None | \n",
+ " NaN | \n",
+ " NaN | \n",
+ " None | \n",
+ " None | \n",
+ " None | \n",
+ " None | \n",
+ " None | \n",
+ " general | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
5 rows × 24 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " _time _subsort _key_hash _key subtype \\\n",
+ "0 2023-07-25 19:42:13 5 15750806798332339587 general message \n",
+ "1 2023-07-25 19:42:14 14 3094307063304068259 random message \n",
+ "2 2023-07-25 19:44:27 0 2954779196800164886 demo message \n",
+ "3 2023-07-26 08:29:35 6 15750806798332339587 general message \n",
+ "4 2023-07-26 08:29:37 7 15750806798332339587 general message \n",
+ "\n",
+ " ts user text \\\n",
+ "0 2023-07-25 19:42:13 U05JQJJDJ6P <@U05JQJJDJ6P> has joined the channel \n",
+ "1 2023-07-25 19:42:14 U05JQJJDJ6P <@U05JQJJDJ6P> has joined the channel \n",
+ "2 2023-07-25 19:44:27 U05JQJJDJ6P <@U05JQJJDJ6P> has joined the channel \n",
+ "3 2023-07-26 08:29:35 U05JQJJDJ6P old message 1 \n",
+ "4 2023-07-26 08:29:37 U05JQJJDJ6P old message 2 \n",
+ "\n",
+ " team user_team ... reactions thread_ts reply_count \\\n",
+ "0 None None ... None None NaN \n",
+ "1 None None ... None None NaN \n",
+ "2 None None ... None None NaN \n",
+ "3 T05JA5XCR9D T05JA5XCR9D ... None None NaN \n",
+ "4 T05JA5XCR9D T05JA5XCR9D ... None None NaN \n",
+ "\n",
+ " reply_users_count latest_reply is_locked subscribed last_read \\\n",
+ "0 NaN None None None None \n",
+ "1 NaN None None None None \n",
+ "2 NaN None None None None \n",
+ "3 NaN None None None None \n",
+ "4 NaN None None None None \n",
+ "\n",
+ " parent_user_id channel \n",
+ "0 None general \n",
+ "1 None random \n",
+ "2 None demo \n",
+ "3 None general \n",
+ "4 None general \n",
+ "\n",
+ "[5 rows x 24 columns]"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
- "def build_conversation(messages):\n",
- " message_time = messages.col(\"ts\")\n",
- " last_message_time = message_time.lag(1) # !!!\n",
- " is_new_conversation = message_time.seconds_since(last_message_time) > 10 * 60\n",
+ "messages = kt.sources.ArrowSource(\n",
+ " data = pandas.read_parquet(\"./messages.parquet\"), \n",
+ " time_column_name = \"ts\", \n",
+ " key_column_name = \"channel\",\n",
+ ")\n",
"\n",
- " return messages \\\n",
- " .select(\"user\", \"ts\", \"text\", \"reactions\") \\\n",
- " .collect(window=kt.windows.Since(is_new_conversation), max=100)"
+ "messages.preview(5)"
]
},
{
"cell_type": "markdown",
- "id": "9247233a",
+ "id": "5076d2bf-6830-460b-a9cb-948d8f106edc",
"metadata": {},
- "source": []
+ "source": [
+ "### Build examples"
+ ]
},
{
"cell_type": "code",
- "execution_count": null,
- "id": "fdb2d959-d371-4026-9f8d-4ab26cfbf317",
+ "execution_count": 6,
+ "id": "af7d2a45-eb89-47ce-b471-a39ad8c7bbc7",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " _time | \n",
+ " _subsort | \n",
+ " _key_hash | \n",
+ " _key | \n",
+ " prompt | \n",
+ " completion | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 2023-07-25 19:43:13 | \n",
+ " 0 | \n",
+ " 15750806798332339587 | \n",
+ " general | \n",
+ " [{'ts': 1690314133000000000, 'user': 'U05JQJJD... | \n",
+ " [] | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 2023-07-25 19:43:14 | \n",
+ " 1 | \n",
+ " 3094307063304068259 | \n",
+ " random | \n",
+ " [{'ts': 1690314134000000000, 'user': 'U05JQJJD... | \n",
+ " [] | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 2023-07-25 19:45:27 | \n",
+ " 2 | \n",
+ " 2954779196800164886 | \n",
+ " demo | \n",
+ " [{'ts': 1690314267000000000, 'user': 'U05JQJJD... | \n",
+ " [] | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 2023-07-26 08:30:35 | \n",
+ " 3 | \n",
+ " 15750806798332339587 | \n",
+ " general | \n",
+ " [{'ts': 1690314133000000000, 'user': 'U05JQJJD... | \n",
+ " [] | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 2023-07-26 08:30:37 | \n",
+ " 4 | \n",
+ " 15750806798332339587 | \n",
+ " general | \n",
+ " [{'ts': 1690314133000000000, 'user': 'U05JQJJD... | \n",
+ " [] | \n",
+ "
\n",
+ " \n",
+ " | 5 | \n",
+ " 2023-07-26 08:31:10 | \n",
+ " 5 | \n",
+ " 2954779196800164886 | \n",
+ " demo | \n",
+ " [{'ts': 1690314267000000000, 'user': 'U05JQJJD... | \n",
+ " [] | \n",
+ "
\n",
+ " \n",
+ " | 6 | \n",
+ " 2023-07-26 08:31:14 | \n",
+ " 6 | \n",
+ " 3094307063304068259 | \n",
+ " random | \n",
+ " [{'ts': 1690314134000000000, 'user': 'U05JQJJD... | \n",
+ " [] | \n",
+ "
\n",
+ " \n",
+ " | 7 | \n",
+ " 2023-07-26 08:31:40 | \n",
+ " 7 | \n",
+ " 3094307063304068259 | \n",
+ " random | \n",
+ " [{'ts': 1690314134000000000, 'user': 'U05JQJJD... | \n",
+ " [] | \n",
+ "
\n",
+ " \n",
+ " | 8 | \n",
+ " 2023-07-26 08:31:49 | \n",
+ " 8 | \n",
+ " 3094307063304068259 | \n",
+ " random | \n",
+ " [{'ts': 1690314134000000000, 'user': 'U05JQJJD... | \n",
+ " [] | \n",
+ "
\n",
+ " \n",
+ " | 9 | \n",
+ " 2023-07-26 08:31:53 | \n",
+ " 9 | \n",
+ " 3094307063304068259 | \n",
+ " random | \n",
+ " [{'ts': 1690314134000000000, 'user': 'U05JQJJD... | \n",
+ " [] | \n",
+ "
\n",
+ " \n",
+ " | 10 | \n",
+ " 2023-07-26 08:31:57 | \n",
+ " 10 | \n",
+ " 3094307063304068259 | \n",
+ " random | \n",
+ " [{'ts': 1690314134000000000, 'user': 'U05JQJJD... | \n",
+ " [] | \n",
+ "
\n",
+ " \n",
+ " | 11 | \n",
+ " 2023-07-26 09:56:07 | \n",
+ " 11 | \n",
+ " 3094307063304068259 | \n",
+ " random | \n",
+ " [{'ts': 1690314134000000000, 'user': 'U05JQJJD... | \n",
+ " [] | \n",
+ "
\n",
+ " \n",
+ " | 12 | \n",
+ " 2023-07-26 09:56:35 | \n",
+ " 12 | \n",
+ " 15750806798332339587 | \n",
+ " general | \n",
+ " [{'ts': 1690314133000000000, 'user': 'U05JQJJD... | \n",
+ " [] | \n",
+ "
\n",
+ " \n",
+ " | 13 | \n",
+ " 2023-07-26 09:56:52 | \n",
+ " 13 | \n",
+ " 2954779196800164886 | \n",
+ " demo | \n",
+ " [{'ts': 1690314267000000000, 'user': 'U05JQJJD... | \n",
+ " [] | \n",
+ "
\n",
+ " \n",
+ " | 14 | \n",
+ " 2023-07-26 19:38:00 | \n",
+ " 14 | \n",
+ " 2954779196800164886 | \n",
+ " demo | \n",
+ " [{'ts': 1690314267000000000, 'user': 'U05JQJJD... | \n",
+ " [] | \n",
+ "
\n",
+ " \n",
+ " | 15 | \n",
+ " 2023-07-26 19:38:00 | \n",
+ " 15 | \n",
+ " 15750806798332339587 | \n",
+ " general | \n",
+ " [{'ts': 1690314133000000000, 'user': 'U05JQJJD... | \n",
+ " [] | \n",
+ "
\n",
+ " \n",
+ " | 16 | \n",
+ " 2023-07-26 19:38:00 | \n",
+ " 16 | \n",
+ " 3094307063304068259 | \n",
+ " random | \n",
+ " [{'ts': 1690314134000000000, 'user': 'U05JQJJD... | \n",
+ " [] | \n",
+ "
\n",
+ " \n",
+ " | 17 | \n",
+ " 2023-07-26 19:44:07 | \n",
+ " 17 | \n",
+ " 3094307063304068259 | \n",
+ " random | \n",
+ " [{'ts': 1690314134000000000, 'user': 'U05JQJJD... | \n",
+ " [U05JQJJDJ6P] | \n",
+ "
\n",
+ " \n",
+ " | 18 | \n",
+ " 2023-07-26 20:17:30 | \n",
+ " 18 | \n",
+ " 2954779196800164886 | \n",
+ " demo | \n",
+ " [{'ts': 1690314267000000000, 'user': 'U05JQJJD... | \n",
+ " [] | \n",
+ "
\n",
+ " \n",
+ " | 19 | \n",
+ " 2023-07-26 20:17:30 | \n",
+ " 19 | \n",
+ " 15750806798332339587 | \n",
+ " general | \n",
+ " [{'ts': 1690314133000000000, 'user': 'U05JQJJD... | \n",
+ " [] | \n",
+ "
\n",
+ " \n",
+ " | 20 | \n",
+ " 2023-07-26 20:17:30 | \n",
+ " 20 | \n",
+ " 3094307063304068259 | \n",
+ " random | \n",
+ " [{'ts': 1690314134000000000, 'user': 'U05JQJJD... | \n",
+ " [U05JQJJDJ6P] | \n",
+ "
\n",
+ " \n",
+ " | 21 | \n",
+ " 2023-07-26 20:19:25 | \n",
+ " 21 | \n",
+ " 15750806798332339587 | \n",
+ " general | \n",
+ " [{'ts': 1690314133000000000, 'user': 'U05JQJJD... | \n",
+ " [U05JQJJDJ6P] | \n",
+ "
\n",
+ " \n",
+ " | 22 | \n",
+ " 2023-07-26 20:32:04 | \n",
+ " 0 | \n",
+ " 15750806798332339587 | \n",
+ " general | \n",
+ " [{'ts': 1690314133000000000, 'user': 'U05JQJJD... | \n",
+ " [U05JQJJDJ6P] | \n",
+ "
\n",
+ " \n",
+ " | 23 | \n",
+ " 2023-07-26 20:32:19 | \n",
+ " 1 | \n",
+ " 15750806798332339587 | \n",
+ " general | \n",
+ " [{'ts': 1690314133000000000, 'user': 'U05JQJJD... | \n",
+ " [U05JQJJDJ6P] | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " _time _subsort _key_hash _key \\\n",
+ "0 2023-07-25 19:43:13 0 15750806798332339587 general \n",
+ "1 2023-07-25 19:43:14 1 3094307063304068259 random \n",
+ "2 2023-07-25 19:45:27 2 2954779196800164886 demo \n",
+ "3 2023-07-26 08:30:35 3 15750806798332339587 general \n",
+ "4 2023-07-26 08:30:37 4 15750806798332339587 general \n",
+ "5 2023-07-26 08:31:10 5 2954779196800164886 demo \n",
+ "6 2023-07-26 08:31:14 6 3094307063304068259 random \n",
+ "7 2023-07-26 08:31:40 7 3094307063304068259 random \n",
+ "8 2023-07-26 08:31:49 8 3094307063304068259 random \n",
+ "9 2023-07-26 08:31:53 9 3094307063304068259 random \n",
+ "10 2023-07-26 08:31:57 10 3094307063304068259 random \n",
+ "11 2023-07-26 09:56:07 11 3094307063304068259 random \n",
+ "12 2023-07-26 09:56:35 12 15750806798332339587 general \n",
+ "13 2023-07-26 09:56:52 13 2954779196800164886 demo \n",
+ "14 2023-07-26 19:38:00 14 2954779196800164886 demo \n",
+ "15 2023-07-26 19:38:00 15 15750806798332339587 general \n",
+ "16 2023-07-26 19:38:00 16 3094307063304068259 random \n",
+ "17 2023-07-26 19:44:07 17 3094307063304068259 random \n",
+ "18 2023-07-26 20:17:30 18 2954779196800164886 demo \n",
+ "19 2023-07-26 20:17:30 19 15750806798332339587 general \n",
+ "20 2023-07-26 20:17:30 20 3094307063304068259 random \n",
+ "21 2023-07-26 20:19:25 21 15750806798332339587 general \n",
+ "22 2023-07-26 20:32:04 0 15750806798332339587 general \n",
+ "23 2023-07-26 20:32:19 1 15750806798332339587 general \n",
+ "\n",
+ " prompt completion \n",
+ "0 [{'ts': 1690314133000000000, 'user': 'U05JQJJD... [] \n",
+ "1 [{'ts': 1690314134000000000, 'user': 'U05JQJJD... [] \n",
+ "2 [{'ts': 1690314267000000000, 'user': 'U05JQJJD... [] \n",
+ "3 [{'ts': 1690314133000000000, 'user': 'U05JQJJD... [] \n",
+ "4 [{'ts': 1690314133000000000, 'user': 'U05JQJJD... [] \n",
+ "5 [{'ts': 1690314267000000000, 'user': 'U05JQJJD... [] \n",
+ "6 [{'ts': 1690314134000000000, 'user': 'U05JQJJD... [] \n",
+ "7 [{'ts': 1690314134000000000, 'user': 'U05JQJJD... [] \n",
+ "8 [{'ts': 1690314134000000000, 'user': 'U05JQJJD... [] \n",
+ "9 [{'ts': 1690314134000000000, 'user': 'U05JQJJD... [] \n",
+ "10 [{'ts': 1690314134000000000, 'user': 'U05JQJJD... [] \n",
+ "11 [{'ts': 1690314134000000000, 'user': 'U05JQJJD... [] \n",
+ "12 [{'ts': 1690314133000000000, 'user': 'U05JQJJD... [] \n",
+ "13 [{'ts': 1690314267000000000, 'user': 'U05JQJJD... [] \n",
+ "14 [{'ts': 1690314267000000000, 'user': 'U05JQJJD... [] \n",
+ "15 [{'ts': 1690314133000000000, 'user': 'U05JQJJD... [] \n",
+ "16 [{'ts': 1690314134000000000, 'user': 'U05JQJJD... [] \n",
+ "17 [{'ts': 1690314134000000000, 'user': 'U05JQJJD... [U05JQJJDJ6P] \n",
+ "18 [{'ts': 1690314267000000000, 'user': 'U05JQJJD... [] \n",
+ "19 [{'ts': 1690314133000000000, 'user': 'U05JQJJD... [] \n",
+ "20 [{'ts': 1690314134000000000, 'user': 'U05JQJJD... [U05JQJJDJ6P] \n",
+ "21 [{'ts': 1690314133000000000, 'user': 'U05JQJJD... [U05JQJJDJ6P] \n",
+ "22 [{'ts': 1690314133000000000, 'user': 'U05JQJJD... [U05JQJJDJ6P] \n",
+ "23 [{'ts': 1690314133000000000, 'user': 'U05JQJJD... [U05JQJJDJ6P] "
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
- "def build_examples(messages):\n",
- " duration = datetime.timedelta(minutes=5)\n",
+ "# Group messages by thread (if present) or channel\n",
+ "#messages = messages.with_key(kt.record({\n",
+ "# \"channel\": messages.col(\"channel\"),\n",
+ "# \"thread\": messages.col(\"thread_ts\"),\n",
+ "# }))\n",
"\n",
- " coverstation = build_conversation(messages)\n",
- " shifted_coversation = coverstation.shift_by(duration) # !!!\n",
"\n",
- " reaction_users = coverstation.col(\"reactions\").col(\"name\").collect(kt.windows.Trailing(duration)).flatten() # !!!\n",
- " participating_users = coverstation.col(\"user\").collect(kt.windows.Trailing(duration)) # !!!\n",
- " engaged_users = kt.union(reaction_users, participating_users) # !!!\n",
+ "# Build the input prompt from recent messages\n",
+ "prompts = messages \\\n",
+ " .select(\"user\", \"ts\", \"text\", \"reactions\") \\\n",
+ " .collect(max=20)\n",
"\n",
- " return kt.record({ \"prompt\": shifted_coversation, \"completion\": engaged_users}) \\\n",
- " .filter(shifted_coversation.is_not_null())"
+ "\n",
+ "# Build the completion from users who engage after the prompt\n",
+ "duration = datetime.timedelta(minutes=1)\n",
+ "\n",
+ "shifted_prompts = prompts.shift_by(duration)\n",
+ "\n",
+ "reaction_users = messages.collect(max=100).col(\"reactions\").flatten().col(\"users\").flatten().last()\n",
+ "#reaction_users = messages.collect(kt.Trailing(duration), max=100).col(\"reactions\").flatten().col(\"users\").flatten().last()\n",
+ "#participating_users = prompts.col(\"user\").collect(max=100) #kt.windows.Trailing(duration))\n",
+ "engaged_users = reaction_users #kt.union(reaction_users, participating_users)\n",
+ "\n",
+ "examples = kt.record({\"prompt\": shifted_prompts, \"completion\": engaged_users}) \\\n",
+ " .filter(shifted_prompts.is_not_null())\n",
+ "\n",
+ "\n",
+ "prompts.preview(5)\n",
+ "examples.preview(100)"
]
},
{
"cell_type": "markdown",
- "id": "0035f558-23bd-4b4d-95a0-ed5e8fece673",
+ "id": "fed3b924-e7de-414b-bba5-b119e40921f0",
"metadata": {},
"source": [
- "## Fine-tune the model"
+ "## Fine-tune a model"
]
},
{
- "cell_type": "code",
- "execution_count": null,
- "id": "af7d2a45-eb89-47ce-b471-a39ad8c7bbc7",
+ "cell_type": "markdown",
+ "id": "e78fa9bd-9c40-403d-a7ee-a15620a88418",
"metadata": {},
- "outputs": [],
"source": [
- "import pandas\n",
- "import sparrow_pi.sources as sources\n",
- "\n",
- "messages = kt.sources.Parquet(\"./messages.parquet\", time = \"ts\", entity = \"channel\")\n",
- "messages = messages.with_key(kt.record({ # !!!\n",
- " \"channel\": messages.col(\"channel\"),\n",
- " \"thread\": messages.col(\"thread_ts\"),\n",
- " }))\n",
- "examples = build_examples(messages)\n",
- "\n",
- "examples_df = examples.run().to_pandas()"
+ "### Create training dataset"
]
},
{
@@ -141,22 +672,36 @@
"source": [
"from sklearn import preprocessing\n",
"\n",
+ "# Extract examples from historical data\n",
+ "examples_df = examples.run().to_pandas().drop([\"_time\", \"_subsort\", \"_key_hash\", \"_key\"], axis=1)\n",
+ "\n",
+ "\n",
+ "# Encode user ID labels\n",
"le = preprocessing.LabelEncoder()\n",
"le.fit(examples_df.completion.explode())\n",
"\n",
+ "\n",
"# Format for the OpenAI API\n",
"def format_prompt(prompt):\n",
" return \"start -> \" + \"\\n\\n\".join([f' {msg[\"user\"]} --> {msg[\"text\"]} ' for msg in prompt]) + \"\\n\\n###\\n\\n\"\n",
"examples_df.prompt = examples_df.prompt.apply(format_prompt)\n",
- "\n",
"def format_completion(completion):\n",
- " return \" \" + (\" \".join([le.transform(u) for u in completion]) if len(completion) > 0 else \"nil\") + \" end\"\n",
+ " return \" \" + (\" \".join(le.transform(completion).astype(str)) if len(completion) > 0 else \"nil\") + \" end\"\n",
"examples_df.completion = examples_df.completion.apply(format_completion)\n",
"\n",
+ "\n",
"# Write examples to file\n",
"examples_df.to_json(\"examples.jsonl\", orient='records', lines=True)"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "cfc60311-5ca1-49f3-8e35-4070174e0258",
+ "metadata": {},
+ "source": [
+ "### Upload to OpenAI"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -169,12 +714,23 @@
"from types import SimpleNamespace\n",
"from openai import cli\n",
"\n",
- "# verifiy data format, split for training & validation\n",
+ "# Initialize OpenAI\n",
+ "openai.api_key = getpass.getpass('OpenAI: API Key')\n",
+ "\n",
+ "# Verifiy data format, split for training & validation, upload to OpenAI\n",
"args = SimpleNamespace(file='./examples.jsonl', quiet=True)\n",
"cli.FineTune.prepare_data(args)\n",
"training_id = cli.FineTune._get_or_upload('./examples_prepared_train.jsonl', True)"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "f0808e6e-8239-4487-b22b-14b5a00948c6",
+ "metadata": {},
+ "source": [
+ "### Create the training job"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -193,103 +749,6 @@
")\n",
"print(f'Fine-tuning model with job ID: \"{resp[\"id\"]}\"')"
]
- },
- {
- "cell_type": "markdown",
- "id": "b3e29109-cc00-4bf5-ba23-069e8db1f179",
- "metadata": {
- "jp-MarkdownHeadingCollapsed": true,
- "tags": []
- },
- "source": [
- "## Notify users of conversations they need to know about"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "540afff7-4ebc-427f-8205-1ed145e0c59a",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "import json, math\n",
- "\n",
- "min_prob_for_response = 0.50\n",
- "\n",
- "# Receive Slack messages in real-time\n",
- "live_messages = kt.sources.ArrowSource(entity_column=\"channel\", time_column=\"ts\")\n",
- "\n",
- "# Receive messages from Slack\n",
- "def handle_message(client, req):\n",
- " # Acknowledge the message back to Slack\n",
- " client.send_socket_mode_response(SocketModeResponse(envelope_id=req.envelope_id))\n",
- " \n",
- " if req.type == \"events_api\" and \"event\" in req.payload:\n",
- " e = req.payload[\"event\"]\n",
- " if \"previous_message\" in e or e[\"type\"] == \"reaction_added\":\n",
- " return\n",
- " # send message events to Kaskada\n",
- " live_messages.add_event(pyarrow.json.read_json(e))\n",
- "\n",
- "slack.socket_mode_request_listeners.append(handle_message)\n",
- "slack.connect()\n",
- "\n",
- "# Handle messages in realtime\n",
- "# A \"conversation\" is a list of messages\n",
- "for conversation in build_conversation(live_messages).start().to_generator():\n",
- " if len(conversation) == 0:\n",
- " continue\n",
- " \n",
- " # Ask the model who should be notified\n",
- " res = openai.Completion.create(\n",
- " model=\"davinci:ft-personal:coversation-users-full-kaskada-2023-08-05-14-25-30\", \n",
- " prompt=format_prompt(conversation),\n",
- " logprobs=5,\n",
- " max_tokens=1,\n",
- " stop=\" end\",\n",
- " temperature=0,\n",
- " )\n",
- "\n",
- " users = []\n",
- " logprobs = res[\"choices\"][0][\"logprobs\"][\"top_logprobs\"][0]\n",
- " for user in logprobs:\n",
- " if math.exp(logprobs[user]) > min_prob_for_response:\n",
- " user = users.strip()\n",
- " # if users include `nil`, stop processing\n",
- " if user == \"nil\":\n",
- " users = []\n",
- " break\n",
- " users.append(user)\n",
- "\n",
- " # alert on most recent message in conversation\n",
- " msg = conversation.pop()\n",
- " \n",
- " # Send notification to users\n",
- " for user in users:\n",
- " user_id = le.inverse_transform(user)\n",
- "\n",
- " # get user channel for slackbot\n",
- " app = slack.web_client.users_conversations(\n",
- " types=\"im\",\n",
- " user=user_id,\n",
- " )\n",
- " \n",
- " # confirm user has slackbot installed\n",
- " if len(app[\"channels\"]) == 0:\n",
- " continue\n",
- "\n",
- " link = slack.web_client.chat_getPermalink(\n",
- " channel=msg[\"channel\"],\n",
- " message_ts=msg[\"ts\"],\n",
- " )[\"permalink\"]\n",
- " \n",
- " slack.web_client.chat_postMessage(\n",
- " channel=app[\"channels\"][0][\"id\"],\n",
- " text=f'You may be interested in this converstation: <{link}|{msg[\"text\"]}>'\n",
- " )"
- ]
}
],
"metadata": {
@@ -308,7 +767,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.9"
+ "version": "3.11.3"
}
},
"nbformat": 4,
diff --git a/examples/slackbot/README.md b/examples/slackbot/README.md
index ee26b707b..08c68b778 100644
--- a/examples/slackbot/README.md
+++ b/examples/slackbot/README.md
@@ -37,15 +37,5 @@ copy(
union_by_name=true)
)
) to 'messages.parquet' (FORMAT PARQUET)
-;
-
- select
- from (
- select * from read_json_auto('data/iloveai-initial-export/*/*.json',
- format='array',
- filename=true,
- union_by_name=true)
- )
- limit 10
;
```
\ No newline at end of file
diff --git a/examples/slackbot/messages.parquet b/examples/slackbot/messages.parquet
index a3ad8a5ac..170705f99 100644
Binary files a/examples/slackbot/messages.parquet and b/examples/slackbot/messages.parquet differ
diff --git a/examples/slackbot/run.py b/examples/slackbot/run.py
new file mode 100644
index 000000000..a2743cc21
--- /dev/null
+++ b/examples/slackbot/run.py
@@ -0,0 +1,128 @@
+import json, math, openai, os
+from slack_sdk.web import WebClient
+from slack_sdk.socket_mode import SocketModeClient
+from slack_sdk.socket_mode.response import SocketModeResponse
+
+output_map = {}
+
+with open('./user_output_map.json', 'r') as file:
+ output_map = json.load(file)
+
+print(f'Loaded output map: {output_map}')
+
+
+# Initialize OpenAI
+openai.api_key = os.environ.get("OPEN_AI_KEY")
+
+# Initialize Slack
+slack = SocketModeClient(
+ app_token=os.environ.get("SLACK_APP_TOKEN"),
+ web_client=WebClient(token=os.environ.get("SLACK_BOT_TOKEN"))
+)
+
+min_prob_for_response = 0.50
+
+# Format for the OpenAI API
+def format_prompt(prompt):
+ return "start -> " + "\n\n".join([f' {msg["user"]} --> {msg["text"]} ' for msg in prompt]) + "\n\n###\n\n"
+
+def handle_conversation(conversation):
+ if len(conversation) == 0:
+ return
+
+ print(f'Starting prediction on conversation: {conversation[0]["text"]}')
+
+ # Ask the model who should be notified
+ res = openai.Completion.create(
+ model="davinci:ft-personal:coversation-users-full-kaskada-2023-08-05-14-25-30",
+ prompt=format_prompt(conversation),
+ max_tokens=1,
+ stop=" end",
+ temperature=0,
+ logprobs=5,
+ )
+
+ users = []
+ logprobs = res["choices"][0]["logprobs"]["top_logprobs"][0]
+
+ print(f'Recieved log probs: {logprobs}')
+ for user in logprobs:
+ if math.exp(logprobs[user]) > min_prob_for_response:
+ # if `nil` user is an option, stop processing
+ user = user.strip()
+ if user == "nil":
+ users = []
+ print('Found nil, stopping.')
+ break
+ users.append(user)
+
+ print(f'Found users to alert: {users}')
+ # alert on most recent message in conversation
+ msg = conversation.pop()
+
+ # Send notification to users
+ for user_num in users:
+ if user_num not in output_map:
+ print(f'User: {user_num} not in output_map, stopping.')
+ else:
+ user_id = output_map[user_num]
+
+ print(f'Found user {user_num} in output map: {user_id}')
+
+ link = slack.web_client.chat_getPermalink(
+ channel=msg["channel"],
+ message_ts=msg["ts"],
+ )["permalink"]
+
+ print(f'Got message link: {link}')
+
+ res = slack.web_client.users_conversations(
+ types="im",
+ user=user_id,
+ )
+ if len(res["channels"]) == 0:
+ print(f'User: {user} hasn\'t installed the slackbot yet')
+ else:
+ app_channel = res["channels"][0]["id"]
+ print(f'Got user\'s bot channel id: {app_channel}')
+
+ slack.web_client.chat_postMessage(
+ channel=app_channel,
+ text=f'You may be interested in this converstation: <{link}|{msg["text"]}>'
+ )
+
+ print(f'Posted alert message')
+
+# Receive Slack messages in real-time
+#live_messages = kt.sources.read_stream(entity_column="channel", time_column="ts")
+
+# Receive messages from Slack
+def handle_message(client, req):
+ # Acknowledge the message back to Slack
+ client.send_socket_mode_response(SocketModeResponse(envelope_id=req.envelope_id))
+
+ # Deliver the message to Kaskada
+ #live_messages.add_event(pyarrow.json.read_json(req.payload))
+
+ if req.type == "events_api" and "event" in req.payload:
+ e = req.payload["event"]
+
+ # ignore message edit, delete, reaction events
+ if "previous_message" in e or e["type"] == "reaction_added":
+ return
+
+ # make single-message conversations for now
+ handle_conversation([e])
+
+
+# Handle messages in realtime
+# A "conversation" is a list of messages
+#for conversation in build_conversation(live_messages).start().to_generator():
+
+
+slack.socket_mode_request_listeners.append(handle_message)
+slack.connect()
+
+# Just not to stop this process
+from threading import Event
+Event().wait()
diff --git a/examples/slackbot/slackbot.py b/examples/slackbot/slackbot.py
index 094b6d68d..ba286c6f2 100644
--- a/examples/slackbot/slackbot.py
+++ b/examples/slackbot/slackbot.py
@@ -1,123 +1,100 @@
-import json, math, openai, os, pyarrow
-from slack_sdk.web import WebClient
-from slack_sdk.socket_mode import SocketModeClient
+import json, math, datetime, openai, os, pyarrow, pandas, asyncio
+from slack_sdk.web.async_client import AsyncWebClient
+from slack_sdk.socket_mode.aiohttp import SocketModeClient
from slack_sdk.socket_mode.response import SocketModeResponse
-import sparrow_pi as kt
+import sparrow_py as kt
-def build_conversation(messages):
- message_time = messages.col("ts")
- last_message_time = message_time.lag(1) # !!!
- is_new_conversation = message_time.seconds_since(last_message_time) > 10 * 60
-
- return messages \
- .select("user", "ts", "text", "reactions") \
- .collect(window=kt.windows.Since(is_new_conversation), max=100)
-
-def build_examples(messages):
- duration = kt.minutes(5) # !!!
-
- coverstation = build_conversation(messages)
- shifted_coversation = coverstation.shift_by(duration) # !!!
-
- reaction_users = coverstation.col("reactions").col("name").collect(kt.windows.Trailing(duration)).flatten() # !!!
- participating_users = coverstation.col("user").collect(kt.windows.Trailing(duration)) # !!!
- engaged_users = kt.union(reaction_users, participating_users) # !!!
-
- return kt.record({ "prompt": shifted_coversation, "completion": engaged_users}) \
- .filter(shifted_coversation.is_not_null())
-
-def format_prompt(prompt):
- return "start -> " + "\n\n".join([f' {msg["user"]} --> {msg["text"]} ' for msg in prompt]) + "\n\n###\n\n"
-
-def main():
+async def main():
+ start = datetime.datetime.now()
+
+ # Load user label map
output_map = {}
-
with open('./user_output_map.json', 'r') as file:
output_map = json.load(file)
- print(f'Loaded output map: {output_map}')
-
- # Initialize Kaskada with a local execution context.
+ # Initialize clients
kt.init_session()
-
- # Initialize OpenAI
openai.api_key = os.environ.get("OPEN_AI_KEY")
-
- # Initialize Slack
slack = SocketModeClient(
app_token=os.environ.get("SLACK_APP_TOKEN"),
- web_client=WebClient(token=os.environ.get("SLACK_BOT_TOKEN"))
+ web_client=AsyncWebClient(token=os.environ.get("SLACK_BOT_TOKEN"))
)
- min_prob_for_response = 0.50
- # Receive Slack messages in real-time
- live_messages = kt.sources.read_stream(entity_column="channel", time_column="ts")
+
+ # Backfill state with historical data
+ historical_data = pandas.read_parquet("./messages.parquet")[:1]
+ schema = pyarrow.Schema.from_pandas(historical_data)
+ messages = kt.sources.ArrowSource(
+ data = historical_data,
+ time_column_name = "ts",
+ key_column_name = "channel",
+ )
+
- # Receive messages from Slack
- def handle_message(client, req):
+
+ # Receive Slack messages in real-time
+ async def handle_message(client, req):
# Acknowledge the message back to Slack
- client.send_socket_mode_response(SocketModeResponse(envelope_id=req.envelope_id))
+ await client.send_socket_mode_response(SocketModeResponse(envelope_id=req.envelope_id))
if req.type == "events_api" and "event" in req.payload:
- e = req.payload["event"]
-
- print(f'Received event from slack websocket: {e}')
-
# ignore message edit, delete, reaction events
- if "previous_message" in e or e["type"] == "reaction_added":
+ if "previous_message" in req.payload["event"] or req.payload["event"]["type"] == "reaction_added":
return
- print(f'Sending message event to kaskada: {e}')
-
- # Deliver the message to Kaskada
- live_messages.add_event(pyarrow.json.read_json(e))
+ req.payload["event"]["ts"] = datetime.datetime.fromtimestamp(float(req.payload["event"]["ts"]))
+ del req.payload["event"]["team"]
+ data = pyarrow.RecordBatch.from_pylist([req.payload["event"]], schema=schema)
+ messages.add_data(data)
slack.socket_mode_request_listeners.append(handle_message)
- slack.connect()
+ await slack.connect()
+
+
+
+ # Compute conversations from individual messages
+ conversations = messages.with_key(kt.record({
+ "channel": messages.col("channel"),
+ "thread": messages.col("thread_ts"),
+ })) \
+ .select("user", "ts", "text", "reactions") \
+ .collect(max=3)
+
- # Handle messages in realtime
- # A "conversation" is a list of messages
- for conversation in build_conversation(live_messages).start().to_generator():
- if len(conversation) == 0:
+
+ # Handle each conversation as it occurs
+ async for row in conversations.run(materialize=True).iter_rows_async():
+ conversation = row[" result"]
+ if len(conversation) == 0 or row["_time"] < start:
continue
print(f'Starting completion on conversation with first message text: {conversation[0]["text"]}')
- prompt = format_prompt(conversation)
-
- print(f'Using prompt: {prompt}')
-
# Ask the model who should be notified
res = openai.Completion.create(
model="davinci:ft-personal:coversation-users-full-kaskada-2023-08-05-14-25-30",
- prompt=prompt,
+ prompt="start -> " + "\n\n".join([f' {msg["user"]} --> {msg["text"]} ' for msg in conversation]) + "\n\n###\n\n",
logprobs=5,
max_tokens=1,
stop=" end",
- temperature=0,
+ temperature=1,
)
- print(f'Received completion response: {res}')
-
+ msg = conversation.pop(0)
users = []
logprobs = res["choices"][0]["logprobs"]["top_logprobs"][0]
-
- print(f'Found logprobs: {logprobs}')
+ print(f"Predicted interest logprobs: {logprobs}")
+ print(f"Notifying users: {users}")
for user in logprobs:
- if math.exp(logprobs[user]) > min_prob_for_response:
- user = users.strip()
+ if math.exp(logprobs[user]) > 0.30:
+ user = user.strip()
# if users include `nil`, stop processing
if user == "nil":
users = []
break
users.append(user)
-
- print(f'Found users to alert: {users}')
-
- # alert on most recent message in conversation
- msg = conversation.pop()
-
+
# Send notification to users
for user_num in users:
if user_num not in output_map:
@@ -153,5 +130,7 @@ def handle_message(client, req):
print(f'Posted alert message')
+ print("Done")
+
if __name__ == "__main__":
- main()
\ No newline at end of file
+ asyncio.run(main())
\ No newline at end of file