
Commit c871915

draganjovanovich, agi, yk, and olliestanley authored
Basic implementation of a plugin system for OA (#2765)
# The plugins

Hi, this is my first PR here, but I was somewhat active on other fronts of OA development. This PR brings some basic plugin functionality to Open Assistant. As discussed with @yk and @andreaskoepf, there are quite a few directions in which something like this could be integrated with OA, but this should serve as an initial proof of concept and an exploratory feature for third-party integrations with OA.

I also included a small **calculator** plugin as a possible candidate for an OA internal plugin, which would act as the default one for people to try out, and also as an example of how one could implement their own plugins. If we decide to include this plugin, a deployment/hosting mechanism for it should be added.

I will push a separate branch in the next couple of days that will serve as an alternative approach, so we can A/B test it along with the new models (SFTs/RLHFs).

I also tried to comment on every weird quirk or decision in the code, so one can easily understand and change it, but there are quite a few places where a simple newline or a `"` character in specific strings can affect LLM performance in the plugin-usage scenario.

I will also try to push some documentation regarding plugin development, but there are already some useful comments in the **calculator** plugin on what to pay attention to.

Here are some of the current UI changes introduced with this PR.
<details>
<summary>Plugin chooser component</summary>
<img width="854" alt="Screenshot 2023-04-20 at 00 55 38" src="https://user-images.githubusercontent.com/13547364/233217078-d2e4e28f-36eb-451e-a655-1679188aed52.png">
</details>

<details>
<summary>Plugin execution details component</summary>
<img width="824" alt="Screenshot 2023-04-19 at 21 40 38" src="https://user-images.githubusercontent.com/13547364/233216884-e69bcf9c-707f-43de-a52d-41db5d92c504.png">
<img width="744" alt="Screenshot 2023-04-19 at 21 40 56" src="https://user-images.githubusercontent.com/13547364/233217161-c114f5b9-881f-4476-a2b1-459179a9353e.png">
<img width="545" alt="Screenshot 2023-04-19 at 21 30 18" src="https://user-images.githubusercontent.com/13547364/233217187-17fb87e5-e4be-43e4-96ac-7cdd84223147.png">
</details>

<details open>
<summary>Plugin assisted answer</summary>
<img width="837" alt="Screenshot 2023-04-19 at 21 29 52" src="https://user-images.githubusercontent.com/13547364/233217260-4986f456-efa5-47a5-aabc-926a8a5a9a2f.png">
<img width="943" alt="Screenshot 2023-04-21 at 18 28 45" src="https://user-images.githubusercontent.com/13547364/233687877-0d0f9ffb-b16a-48de-96ad-e4c3a02f4c66.png">
</details>

<details>
<summary>Verified plugin usage UI look</summary>
<img width="864" alt="Screenshot 2023-04-20 at 15 08 36" src="https://user-images.githubusercontent.com/13547364/233376402-52ed5a3d-631a-4350-9130-61548a8d7b02.png">
</details>

<details>
<summary>Some plugin usage examples</summary>
<img width="1048" alt="Screenshot 2023-04-18 at 01 57 33" src="https://user-images.githubusercontent.com/13547364/233217685-79b262bd-81fd-4641-9ad9-110e8b689e42.png">
<img width="993" alt="Screenshot 2023-04-17 at 23 17 35" src="https://user-images.githubusercontent.com/13547364/233217687-561773a1-b16a-49f5-bdbc-f30b46bed33d.png">
</details>

<details open>
<summary>Mixed usage example where model chooses not to use plugin on its own</summary>
<img width="690" alt="Screenshot 2023-04-20 at 21 31 46" src="https://user-images.githubusercontent.com/13547364/233469420-25c72893-7c7a-426c-9f4b-ce9144d643ae.png">
</details>

---------

Co-authored-by: agi <you@example.com>
Co-authored-by: Yannic Kilcher <yk@users.noreply.github.com>
Co-authored-by: Oliver Stanley <olivergestanley@gmail.com>
1 parent 065186a commit c871915


41 files changed: +1970 −43 lines

inference/full-dev-setup.sh

Lines changed: 4 additions & 2 deletions

```diff
@@ -13,10 +13,12 @@ else
   INFERENCE_TAG=latest
 fi
 
+POSTGRES_PORT=${POSTGRES_PORT:-5432}
+
 # Creates a tmux window with splits for the individual services
 
 tmux new-session -d -s "inference-dev-setup"
-tmux send-keys "docker run --rm -it -p 5732:5432 -e POSTGRES_PASSWORD=postgres --name postgres postgres" C-m
+tmux send-keys "docker run --rm -it -p $POSTGRES_PORT:5432 -e POSTGRES_PASSWORD=postgres --name postgres postgres" C-m
 tmux split-window -h
 tmux send-keys "docker run --rm -it -p 6779:6379 --name redis redis" C-m
@@ -30,7 +32,7 @@ fi
 tmux split-window -h
 tmux send-keys "cd server" C-m
-tmux send-keys "LOGURU_LEVEL=$LOGLEVEL POSTGRES_PORT=5732 REDIS_PORT=6779 DEBUG_API_KEYS='0000,0001' ALLOW_DEBUG_AUTH=True TRUSTED_CLIENT_KEYS=6969 uvicorn main:app" C-m
+tmux send-keys "LOGURU_LEVEL=$LOGLEVEL POSTGRES_PORT=$POSTGRES_PORT REDIS_PORT=6779 DEBUG_API_KEYS='0000,0001' ALLOW_DEBUG_AUTH=True TRUSTED_CLIENT_KEYS=6969 uvicorn main:app" C-m
 tmux split-window -h
 tmux send-keys "cd text-client" C-m
 tmux send-keys "sleep 5" C-m
```
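The port override introduced here relies on standard shell default expansion. A quick sketch of how `${POSTGRES_PORT:-5432}` behaves:

```shell
# ${VAR:-default} expands to the default only while VAR is unset (or empty),
# so callers can override the host port, e.g.:
#   POSTGRES_PORT=5732 ./inference/full-dev-setup.sh
unset POSTGRES_PORT
echo "port: ${POSTGRES_PORT:-5432}"   # falls back to the default
POSTGRES_PORT=5732
echo "port: ${POSTGRES_PORT:-5432}"   # explicit value wins
```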
Lines changed: 28 additions & 0 deletions

```python
"""added used plugin to message

Revision ID: 5b4211625a9f
Revises: ea19bbc743f9
Create Date: 2023-05-01 22:53:16.297495

"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "5b4211625a9f"
down_revision = "ea19bbc743f9"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column("message", sa.Column("used_plugin", postgresql.JSONB(astext_type=sa.Text()), nullable=True))
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("message", "used_plugin")
    # ### end Alembic commands ###
```

inference/server/oasst_inference_server/chat_repository.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -88,12 +88,13 @@ async def abort_work(self, message_id: str, reason: str) -> models.DbMessage:
         await self.session.refresh(message)
         return message
 
-    async def complete_work(self, message_id: str, content: str) -> models.DbMessage:
+    async def complete_work(self, message_id: str, content: str, used_plugin: inference.PluginUsed) -> models.DbMessage:
         logger.debug(f"Completing work on message {message_id}")
         message = await self.get_assistant_message_by_id(message_id)
         message.state = inference.MessageState.complete
         message.work_end_at = datetime.datetime.utcnow()
         message.content = content
+        message.used_plugin = used_plugin
         await self.session.commit()
         logger.debug(f"Completed work on message {message_id}")
         await self.session.refresh(message)
```

inference/server/oasst_inference_server/database.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -43,6 +43,8 @@ def custom_json_deserializer(s):
             return chat_schema.CreateMessageRequest.parse_obj(d)
         case "WorkRequest":
             return inference.WorkRequest.parse_obj(d)
+        case "PluginUsed":
+            return inference.PluginUsed.parse_obj(d)
         case None:
             return d
         case _:
```

inference/server/oasst_inference_server/models/chat.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -28,6 +28,8 @@ class DbMessage(SQLModel, table=True):
     safety_label: str | None = Field(None)
     safety_rots: str | None = Field(None)
 
+    used_plugin: inference.PluginUsed | None = Field(None, sa_column=sa.Column(pg.JSONB))
+
     state: inference.MessageState = Field(inference.MessageState.manual)
     work_parameters: inference.WorkParameters = Field(None, sa_column=sa.Column(pg.JSONB))
     work_begin_at: datetime.datetime | None = Field(None)
@@ -68,6 +70,7 @@ def to_read(self) -> inference.MessageRead:
             safety_level=self.safety_level,
             safety_label=self.safety_label,
             safety_rots=self.safety_rots,
+            used_plugin=self.used_plugin,
         )
```

inference/server/oasst_inference_server/routes/chats.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -140,6 +140,7 @@ async def create_assistant_message(
     work_parameters = inference.WorkParameters(
         model_config=model_config,
         sampling_parameters=request.sampling_parameters,
+        plugins=request.plugins,
     )
     assistant_message = await ucr.initiate_assistant_message(
         parent_id=request.parent_id,
```

inference/server/oasst_inference_server/routes/configs.py

Lines changed: 91 additions & 0 deletions

```diff
@@ -1,9 +1,19 @@
+import asyncio
+
+import aiohttp
 import fastapi
 import pydantic
+import yaml
+from aiohttp.client_exceptions import ClientConnectorError, ServerTimeoutError
+from fastapi import HTTPException
+from loguru import logger
 from oasst_inference_server.settings import settings
 from oasst_shared import model_configs
 from oasst_shared.schemas import inference
 
+# NOTE: Populate this with plugins that we will provide out of the box
+OA_PLUGINS = []
+
 router = fastapi.APIRouter(
     prefix="/configs",
     tags=["configs"],
@@ -63,6 +73,16 @@ class ModelConfigInfo(pydantic.BaseModel):
                 repetition_penalty=1.2,
             ),
         ),
+        ParameterConfig(
+            name="k50-Plugins",
+            description="Top-k sampling with k=50 and temperature=0.35",
+            sampling_parameters=inference.SamplingParameters(
+                max_new_tokens=1024,
+                temperature=0.35,
+                top_k=50,
+                repetition_penalty=(1 / 0.90),
+            ),
+        ),
         ParameterConfig(
             name="nucleus9",
             description="Nucleus sampling with p=0.9",
@@ -93,6 +113,44 @@ class ModelConfigInfo(pydantic.BaseModel):
 ]
 
 
+async def fetch_plugin(url: str, retries: int = 3, timeout: float = 5.0) -> inference.PluginConfig:
+    async with aiohttp.ClientSession() as session:
+        for attempt in range(retries):
+            try:
+                async with session.get(url, timeout=timeout) as response:
+                    content_type = response.headers.get("Content-Type")
+
+                    if response.status == 200:
+                        if "application/json" in content_type or url.endswith(".json"):
+                            config = await response.json()
+                        elif (
+                            "application/yaml" in content_type
+                            or "application/x-yaml" in content_type
+                            or url.endswith(".yaml")
+                            or url.endswith(".yml")
+                        ):
+                            config = yaml.safe_load(await response.text())
+                        else:
+                            raise HTTPException(
+                                status_code=400,
+                                detail=f"Unsupported content type: {content_type}. Only JSON and YAML are supported.",
+                            )
+
+                        return inference.PluginConfig(**config)
+                    elif response.status == 404:
+                        raise HTTPException(status_code=404, detail="Plugin not found")
+                    else:
+                        raise HTTPException(status_code=response.status, detail="Unexpected status code")
+            except (ClientConnectorError, ServerTimeoutError) as e:
+                if attempt == retries - 1:  # last attempt
+                    raise HTTPException(status_code=500, detail=f"Request failed after {retries} retries: {e}")
+                await asyncio.sleep(2**attempt)  # exponential backoff
+
+            except aiohttp.ClientError as e:
+                raise HTTPException(status_code=500, detail=f"Request failed: {e}")
+    raise HTTPException(status_code=500, detail="Failed to fetch plugin")
+
+
 @router.get("/model_configs")
 async def get_model_configs() -> list[ModelConfigInfo]:
     return [
@@ -103,3 +161,36 @@ async def get_model_configs() -> list[ModelConfigInfo]:
         for model_config_name in model_configs.MODEL_CONFIGS
         if (settings.allowed_model_config_names == "*" or model_config_name in settings.allowed_model_config_names_list)
     ]
+
+
+@router.post("/plugin_config")
+async def get_plugin_config(plugin: inference.PluginEntry) -> inference.PluginEntry:
+    try:
+        plugin_config = await fetch_plugin(plugin.url)
+    except HTTPException as e:
+        logger.warning(f"Failed to fetch plugin config from {plugin.url}: {e.detail}")
+        raise fastapi.HTTPException(status_code=e.status_code, detail=e.detail)
+
+    return inference.PluginEntry(url=plugin.url, enabled=plugin.enabled, plugin_config=plugin_config)
+
+
+@router.get("/builtin_plugins")
+async def get_builtin_plugins() -> list[inference.PluginEntry]:
+    plugins = []
+
+    for plugin in OA_PLUGINS:
+        try:
+            plugin_config = await fetch_plugin(plugin.url)
+        except HTTPException as e:
+            logger.warning(f"Failed to fetch plugin config from {plugin.url}: {e.detail}")
+            continue
+
+        final_plugin: inference.PluginEntry = inference.PluginEntry(
+            url=plugin.url,
+            enabled=plugin.enabled,
+            trusted=plugin.trusted,
+            plugin_config=plugin_config,
+        )
+        plugins.append(final_plugin)
+
+    return plugins
```
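`fetch_plugin` wraps the HTTP call in a retry loop with exponential backoff, giving up only after the last attempt. A stripped-down sketch of that loop, with the network call replaced by a stub coroutine so it runs standalone (`fetch_with_retries` and `flaky_fetch` are illustrative names, not the OA API):

```python
# Retry-with-exponential-backoff: wait 2**attempt between transient failures,
# re-raise on the final attempt. The delay base is shrunk for the demo.
import asyncio


async def fetch_with_retries(fetch, retries: int = 3, base_delay: float = 0.001):
    for attempt in range(retries):
        try:
            return await fetch()
        except ConnectionError:
            if attempt == retries - 1:  # last attempt: give up
                raise
            await asyncio.sleep(base_delay * 2**attempt)  # exponential backoff


calls = {"n": 0}


async def flaky_fetch():
    # Fails twice with a transient error, then returns a plugin-config-shaped dict
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("transient network error")
    return {"schema_version": "v1", "name_for_human": "Calculator"}


result = asyncio.run(fetch_with_retries(flaky_fetch))
```

Note that in the real endpoint only connection and timeout errors are retried; HTTP-level failures (404, bad content type) surface immediately as `HTTPException`.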

inference/server/oasst_inference_server/routes/workers.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -350,6 +350,7 @@ async def handle_generated_text_response(
     message = await cr.complete_work(
         message_id=message_id,
         content=response.text,
+        used_plugin=response.used_plugin,
     )
     logger.info(f"Completed work for {message_id=}")
     message_packet = inference.InternalFinishedMessageResponse(
```

inference/server/oasst_inference_server/schemas/chat.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -14,6 +14,8 @@ class CreateAssistantMessageRequest(pydantic.BaseModel):
     parent_id: str
     model_config_name: str
     sampling_parameters: inference.SamplingParameters = pydantic.Field(default_factory=inference.SamplingParameters)
+    plugins: list[inference.PluginEntry] = pydantic.Field(default_factory=list[inference.PluginEntry])
+    used_plugin: inference.PluginUsed | None = None
 
 
 class PendingResponseEvent(pydantic.BaseModel):
```
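A small quirk in this diff: `default_factory=list[inference.PluginEntry]` works because calling a parameterized generic alias constructs the underlying origin type. A quick stdlib-only check of that behavior:

```python
# Calling a generic alias like list[int] is equivalent to calling list(),
# so it serves as a default_factory that yields a fresh, independent empty
# list for each model instance. list[int] stands in for list[inference.PluginEntry].
PluginList = list[int]

a = PluginList()
b = PluginList()
a.append(1)  # mutating one instance's default must not affect another
```

The type parameter is discarded at runtime; `list` alone would behave identically, but spelling out the element type keeps the factory consistent with the field annotation.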

inference/worker/PLUGINS.md

Lines changed: 72 additions & 0 deletions

```markdown
# Plugin system for OA

This is a basic implementation of support for external augmentation and
OpenAI/ChatGPT plugins in Open-Assistant. In its current state it is more of a
proof of concept and should be used behind an experimental flag.

## Architecture

There is now a middleware layer between work.py (the worker) and the final
prompt that is passed to the inference server for generation and streaming.
That middleware checks whether a plugin is enabled in the userland/UI, and if
so, it takes over the job of creating curated pre-prompts for plugin usage, as
well as issuing subsequent LLM calls (inner monologues) to produce the final,
externally **augmented** prompt, which is passed back to the worker and on to
inference for the final generation/streaming of tokens to the frontend.

## Plugins

Plugins are, in essence, wrappers around APIs that help the LLM use those
APIs more precisely and reliably, so they can be quite useful and powerful
augmentation tools for Open-Assistant. A plugin has two main parts: the
ai-plugin.json file, which is the plugin's main descriptor, and the OpenAPI
specification of the plugin's API.

Here is the OpenAI plugin
[specification](https://platform.openai.com/docs/plugins/getting-started),
which this system currently supports partially.

For now, only plugins without authentication and with **GET** requests only
are supported. Some of them are:

- https://www.klarna.com/.well-known/ai-plugin.json
- https://www.joinmilo.com/.well-known/ai-plugin.json

Adding support for the other request types would be quite tricky with the
current approach. It would be best to drop the current "mansplaining" of the
API to the LLM and just show it the complete JSON/YAML content, but for that
to be as reliable as the current approach we would need a larger context size
and somewhat more capable models.

Quite a few more plugins can be found on the
[plugin "store" wellknown.ai](https://www.wellknown.ai/).

One of the ideas of the plugin system is that we can have some internal OA
plugins that work out of the box, alongside endless third-party,
community-developed plugins.

### Notes on the reliability, performance, and limitations of the plugin system

Performance can vary a lot depending on the models and plugins used. Some work
better than others, but that aspect should improve as models get better. The
biggest limitations at the moment are context size and instruction-following
capability. These are mitigated with prompt tricks, truncation of the plugin
OpenAPI descriptions, and dynamic inclusion/exclusion of parts of the prompts
while internally generating the intermediate texts (inner monologues). More of
the limitations and possible alternatives are explained in code comments.

The current approach is somewhat hybrid, I would say, and relies on the
zero-shot capabilities of a model. There will be one more branch with a plugin
system that takes a different approach, utilizing smaller embedding
transformer models and vector stores, so we can A/B test the two systems
alongside new OA model releases.

## Relevant files for the inference side of the plugin system

- chat_chain.py
- chat_chain_utils.py _(tweaking tools/plugin description string generation
  can help for some models)_
- chat_chain_prompts.py _(tweaking prompts can also help)_
```
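Per the OpenAI plugin specification referenced in PLUGINS.md, a no-auth, GET-only plugin is described by an ai-plugin.json manifest roughly like the following; all URLs, names, and the contact email below are illustrative placeholders, not a real deployed plugin:

```python
# Minimal ai-plugin.json manifest in the shape the OpenAI plugin spec
# describes. A plugin host serves this file at /.well-known/ai-plugin.json;
# the "api" entry points at the plugin's OpenAPI specification.
import json

AI_PLUGIN_MANIFEST = {
    "schema_version": "v1",
    "name_for_human": "Calculator",
    "name_for_model": "calculator",
    "description_for_human": "Evaluate basic math expressions.",
    "description_for_model": "Use this tool to evaluate arithmetic expressions.",
    "auth": {"type": "none"},  # only no-auth plugins are currently supported by OA
    "api": {"type": "openapi", "url": "http://localhost:8085/openapi.json"},
    "logo_url": "http://localhost:8085/logo.png",
    "contact_email": "dev@example.com",
    "legal_info_url": "http://example.com/legal",
}

# Round-trip through JSON, as a plugin host would serve this file
manifest = json.loads(json.dumps(AI_PLUGIN_MANIFEST))
```

`description_for_model` is the field worth tuning most carefully: it is injected into the pre-prompts the middleware builds, so its wording directly affects how reliably the model picks the plugin.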
