Skip to content

Commit cf06697

Browse files
feat: Zero-shot NLU (#189)
* Add zero shot NLU using LLMs * add docker-compose for ollama * switch between default NLU and LLM pipelines * add synonym replacer NLU component * refactor docker-compose * update docs
1 parent c23c561 commit cf06697

30 files changed

+663
-179
lines changed

README.md

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ You don’t need to be an expert at artificial intelligence to create an awesome
1818
- Spacy Word Embeddings
1919
- Intent Recognition (ML)
2020
- Entity Extraction (ML)
21-
- One shot NLU using Large Language Models (Coming Soon)
21+
- Zero shot NLU using Large Language Models (LLMs)
2222
- Persistent Memory & Context Management
2323
- API request fulfilment
2424
- Channel Integrations
@@ -54,7 +54,7 @@ You don’t need to be an expert at artificial intelligence to create an awesome
5454
docker-compose up -d
5555
```
5656

57-
Open http://localhost:3000/
57+
Open http://localhost:8080/
5858

5959
### Using Helm
6060

@@ -75,8 +75,6 @@ Want to contribute? Check out our [contribution guidelines](CONTRIBUTING.md).
7575

7676
### Tutorial
7777

78-
Checkout this basic tutorial on youtube,
79-
80-
[![Coming Soon](https://www.wpcc.edu/wp-content/uploads/2021/04/YouTube-Stream-Coming-Soon.jpg)](https://www.youtube.com/watch?v=S1Fj7WinaBA)
78+
Check out our [tutorial](docs/01-getting-started.md) to get started.
8179

8280
<hr></hr>

app/admin/bots/routes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ async def set_config(name: str, config: Dict[str, Any]):
1313
"""
1414
Update bot config
1515
"""
16-
await store.update_config(name, config)
16+
await store.update_nlu_config(name, config)
1717
return {"message": "Config updated successfully"}
1818

1919

@@ -22,7 +22,7 @@ async def get_config(name: str):
2222
"""
2323
Get bot config
2424
"""
25-
return await store.get_config(name)
25+
return await store.get_nlu_config(name)
2626

2727

2828
@router.get("/{name}/export")

app/admin/bots/schemas.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,40 @@
1-
from pydantic import BaseModel, Field, ConfigDict
2-
from typing import Dict, Any
1+
from pydantic import BaseModel, Field
2+
from typing import Optional
33
from app.database import ObjectIdField
4+
from datetime import datetime
5+
6+
7+
class TraditionalNLUSettings(BaseModel):
8+
"""Settings for traditional ML-based NLU pipeline"""
9+
10+
intent_detection_threshold: float = 0.75
11+
entity_detection_threshold: float = 0.65
12+
use_spacy: bool = True
13+
14+
15+
class LLMSettings(BaseModel):
16+
"""Settings for LLM-based NLU pipeline"""
17+
18+
base_url: str = "http://127.0.0.1:11434/v1"
19+
api_key: str = "ollama"
20+
model_name: str = "llama2:13b-chat"
21+
max_tokens: int = 4096
22+
temperature: float = 0.7
23+
24+
25+
class NLUConfiguration(BaseModel):
26+
"""Configuration for Natural Language Understanding"""
27+
28+
pipeline_type: str = "traditional" # Either 'traditional' or 'llm'
29+
traditional_settings: TraditionalNLUSettings = TraditionalNLUSettings()
30+
llm_settings: LLMSettings = LLMSettings()
431

532

633
class Bot(BaseModel):
734
"""Base schema for bot"""
835

936
id: ObjectIdField = Field(validation_alias="_id", default=None)
1037
name: str
11-
config: Dict[str, Any] = {}
12-
13-
model_config = ConfigDict(arbitrary_types_allowed=True)
38+
nlu_config: NLUConfiguration = NLUConfiguration()
39+
created_at: Optional[datetime] = None
40+
updated_at: Optional[datetime] = None

app/admin/bots/store.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,42 @@
11
from typing import Dict
2-
from app.admin.bots.schemas import Bot
2+
from app.admin.bots.schemas import Bot, NLUConfiguration
33
from app.admin.entities.store import list_entities, bulk_import_entities
44
from app.admin.intents.store import list_intents, bulk_import_intents
55
from app.database import database
6+
from datetime import datetime
67

78
bot_collection = database.get_collection("bot")
89

910

10-
async def add_bot(data: dict):
11-
await bot_collection.insert_one(data)
11+
async def ensure_default_bot():
12+
# Check if the default bot exists
13+
default_bot = await bot_collection.find_one({"name": "default"})
14+
if default_bot is None:
15+
# Create the default bot
16+
default_bot_data = Bot(name="default")
17+
default_bot_data.created_at = datetime.utcnow()
18+
default_bot_data.updated_at = datetime.utcnow()
19+
await bot_collection.insert_one(
20+
default_bot_data.model_dump(exclude={"id": True})
21+
)
22+
return default_bot_data
23+
return Bot.model_validate(default_bot)
1224

1325

1426
async def get_bot(name: str) -> Bot:
1527
bot = await bot_collection.find_one({"name": name})
1628
return Bot.model_validate(bot)
1729

1830

19-
async def get_config(name: str) -> Dict:
31+
async def get_nlu_config(name: str) -> NLUConfiguration:
2032
bot = await get_bot(name)
21-
return bot.config
33+
return bot.nlu_config
2234

2335

24-
async def update_config(name: str, entity_data: dict):
25-
await bot_collection.update_one({"name": name}, {"$set": {"config": entity_data}})
36+
async def update_nlu_config(name: str, nlu_config: dict):
37+
await bot_collection.update_one(
38+
{"name": name}, {"$set": {"nlu_config": nlu_config}}
39+
)
2640

2741

2842
async def export_bot(name) -> Dict:

app/admin/train/routes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from fastapi import APIRouter, HTTPException, BackgroundTasks
22
from app.admin.intents import store
33
from app.dependencies import reload_dialogue_manager
4-
from app.bot.nlu.training import train_pipeline
4+
from app.bot.nlu.pipeline_utils import train_pipeline
55

66
router = APIRouter(prefix="/train", tags=["train"])
77

app/bot/dialogue_manager/dialogue_manager.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,8 @@
88
from app.bot.memory.memory_saver_mongo import MemorySaverMongo
99
from app.bot.memory.models import State
1010
from app.bot.nlu.pipeline import NLUPipeline
11-
from app.bot.nlu.featurizers import SpacyFeaturizer
12-
from app.bot.nlu.intent_classifiers import SklearnIntentClassifier
13-
from app.bot.nlu.entity_extractors import CRFEntityExtractor
11+
from app.bot.nlu.pipeline_utils import get_pipeline
1412
from app.bot.dialogue_manager.utils import SilentUndefined, split_sentence
15-
from app.admin.entities.store import list_synonyms
1613
from app.bot.dialogue_manager.models import (
1714
IntentModel,
1815
ParameterModel,
@@ -48,27 +45,21 @@ async def from_config(cls):
4845
Initialize DialogueManager with all required dependencies
4946
"""
5047

51-
synonyms = await list_synonyms()
52-
53-
# Initialize pipeline with components
54-
nlu_pipeline = NLUPipeline(
55-
[
56-
SpacyFeaturizer(app_config.SPACY_LANG_MODEL),
57-
SklearnIntentClassifier(),
58-
CRFEntityExtractor(synonyms),
59-
]
60-
)
61-
6248
# Load all intents and convert to domain models
6349
db_intents = await list_intents()
6450
intents = [IntentModel.from_db(intent) for intent in db_intents]
6551

52+
# Initialize pipeline with components
53+
nlu_pipeline = await get_pipeline()
54+
6655
# Get configuration
6756
fallback_intent_id = app_config.DEFAULT_FALLBACK_INTENT_NAME
6857

6958
# Get bot configuration
7059
bot = await get_bot("default")
71-
confidence_threshold = bot.config.get("confidence_threshold", 0.90)
60+
confidence_threshold = (
61+
bot.nlu_config.traditional_settings.intent_detection_threshold
62+
)
7263

7364
memory_saver = MemorySaverMongo(client)
7465

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .crf_entity_extractor import CRFEntityExtractor
2+
from .synonym_replacer import SynonymReplacer
23

3-
__all__ = ["CRFEntityExtractor"]
4+
__all__ = ["CRFEntityExtractor", "SynonymReplacer"]

app/bot/nlu/entity_extractors/crf_entity_extractor.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pycrfsuite
22
import logging
3-
from typing import Dict, Any, List, Optional
3+
from typing import Dict, Any, List
44
from app.bot.nlu.pipeline import NLUComponent
55
import os
66

@@ -13,23 +13,9 @@ class CRFEntityExtractor(NLUComponent):
1313
Performs NER training, prediction, model import/export
1414
"""
1515

16-
def __init__(self, synonyms: Optional[Dict[str, str]] = None):
17-
self.synonyms = synonyms or {}
16+
def __init__(self):
1817
self.tagger = None
1918

20-
def replace_synonyms(self, entities):
21-
"""
22-
replace extracted entity values with
23-
root word by matching with synonyms dict.
24-
:param entities:
25-
:return:
26-
"""
27-
for entity in entities.keys():
28-
entity_value = str(entities[entity])
29-
if entity_value.lower() in self.synonyms:
30-
entities[entity] = self.synonyms[entity_value.lower()]
31-
return entities
32-
3319
def extract_features(self, sent, i):
3420
"""
3521
Extract features for a given sentence
@@ -178,8 +164,7 @@ def predict(self, message):
178164
tagged_token = self.pos_tagger(spacy_doc)
179165
words = [token.text for token in spacy_doc]
180166
predicted_labels = self.tagger.tag(self.sent_to_features(tagged_token))
181-
extracted_entities = self.crf2json(zip(words, predicted_labels))
182-
return self.replace_synonyms(extracted_entities)
167+
return self.crf2json(zip(words, predicted_labels))
183168

184169
def pos_tagger(self, spacy_doc):
185170
"""
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import logging
2+
from typing import Dict, Any, Optional
3+
from app.bot.nlu.pipeline import NLUComponent
4+
5+
logger = logging.getLogger(__name__)
6+
7+
8+
class SynonymReplacer(NLUComponent):
9+
"""
10+
Replaces extracted entity values with their root words
11+
using a predefined synonyms dictionary.
12+
"""
13+
14+
def __init__(self, synonyms: Optional[Dict[str, str]] = None):
15+
self.synonyms = synonyms or {}
16+
17+
def replace_synonyms(self, entities: Dict[str, str]) -> Dict[str, str]:
18+
"""
19+
Replace extracted entity values with root words by matching with synonyms dict.
20+
:param entities: Dictionary of entity name to entity value mappings
21+
:return: Dictionary with replaced entity values where applicable
22+
"""
23+
for entity in entities.keys():
24+
entity_value = str(entities[entity])
25+
if entity_value.lower() in self.synonyms:
26+
entities[entity] = self.synonyms[entity_value.lower()]
27+
return entities
28+
29+
def train(self, training_data: Dict[str, Any], model_path: str) -> None:
30+
"""Nothing to train for synonym replacement."""
31+
pass
32+
33+
def load(self, model_path: str) -> bool:
34+
"""Nothing to load for synonym replacement."""
35+
return True
36+
37+
def process(self, message: Dict[str, Any]) -> Dict[str, Any]:
38+
"""Process a message by replacing entity values with their synonyms."""
39+
if not message.get("entities"):
40+
return message
41+
42+
entities = message["entities"]
43+
message["entities"] = self.replace_synonyms(entities)
44+
return message

app/bot/nlu/llm/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .zero_shot_nlu_openai import ZeroShotNLUOpenAI
2+
3+
__all__ = ["ZeroShotNLUOpenAI"]

0 commit comments

Comments (0)