Skip to content

Commit 51e7b34

Browse files
authored
Merge pull request #665 from praekeltfoundation/nlu-intent-labelling
Nluclassifier for feedback labelling
2 parents 8640f97 + 1371784 commit 51e7b34

File tree

7 files changed

+495
-0
lines changed

7 files changed

+495
-0
lines changed

ndoh_hub/settings.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
"changes",
6262
"eventstore",
6363
"aaq",
64+
"nluclassifier",
6465
)
6566

6667

@@ -110,6 +111,11 @@
110111
"handlers": ["console"],
111112
"propagate": False,
112113
},
114+
"nluclassifier": {
115+
"handlers": ["console"],
116+
"level": "INFO",
117+
"propagate": False,
118+
},
113119
},
114120
}
115121

ndoh_hub/urls.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@
9494
path("api/v3/", include(v3router.urls)),
9595
path("api/v4/", include(v4router.urls)),
9696
path("api/v5/", include(v5router.urls)),
97+
path("nluclassifier/", include("nluclassifier.urls")),
9798
path(
9899
"metrics", internal_only(django_prometheus.ExportToDjangoView), name="metrics"
99100
),

nluclassifier/tasks.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import logging
2+
from urllib.parse import urljoin
3+
4+
import requests
5+
from celery.exceptions import SoftTimeLimitExceeded
6+
from django.conf import settings
7+
from requests.auth import HTTPBasicAuth
8+
from requests.exceptions import RequestException
9+
10+
from ndoh_hub.celery import app
11+
12+
logger = logging.getLogger(__name__)
13+
14+
CELERY_TASK_OPTIONS = {
15+
"autoretry_for": (RequestException, SoftTimeLimitExceeded),
16+
"retry_backoff": True,
17+
"max_retries": 15,
18+
"acks_late": True,
19+
"soft_time_limit": 10,
20+
"time_limit": 15,
21+
}
22+
23+
24+
@app.task(**CELERY_TASK_OPTIONS)
25+
def process_feedback_for_labeling(message_id: str, inbound_message: str):
26+
"""
27+
Call the NLU endpoint for feedback classification.
28+
LLabel the message in Turn using the detected intent.
29+
"""
30+
31+
nlu_label = None
32+
33+
NLU_URL = settings.INTENT_CLASSIFIER_URL
34+
NLU_USER = settings.INTENT_CLASSIFIER_USER
35+
NLU_PASS = settings.INTENT_CLASSIFIER_PASS
36+
# fmt: off
37+
try:
38+
nlu_endpoint = urljoin(NLU_URL, "/nlu/feedback/")
39+
params = {"question": inbound_message}
40+
41+
logger.info(f"Calling NLU for message {message_id} at {nlu_endpoint}")
42+
43+
response = requests.get(
44+
nlu_endpoint,
45+
params=params,
46+
auth=HTTPBasicAuth(NLU_USER, NLU_PASS),
47+
timeout=CELERY_TASK_OPTIONS["soft_time_limit"],
48+
)
49+
response.raise_for_status()
50+
51+
nlu_label = response.json().get("intent")
52+
53+
logger.info(f"NLU Intent for {message_id}: {nlu_label}")
54+
55+
except RequestException as e:
56+
logger.warning(
57+
f"NLU service failed for message {message_id}. "
58+
f"Retrying. Error: {e}"
59+
)
60+
raise
61+
except Exception as e:
62+
logger.error(
63+
f"NLU non-retriable failure for message {message_id}: "
64+
f"{e}"
65+
)
66+
return
67+
68+
if nlu_label and nlu_label.lower() in ("compliment", "complaint"):
69+
70+
TURN_TOKEN = settings.TURN_TOKEN
71+
TURN_URL = settings.TURN_URL
72+
73+
turn_endpoint = urljoin(TURN_URL, f"v1/messages/{message_id}/labels")
74+
75+
label_payload = {"labels": [nlu_label.lower()]}
76+
77+
headers = {
78+
"Authorization": f"Bearer {TURN_TOKEN}",
79+
"Accept": "application/vnd.v1+json",
80+
"Content-Type": "application/json",
81+
}
82+
83+
try:
84+
logger.info(
85+
f"Labeling message {message_id} in Turn at {turn_endpoint} "
86+
f"with label: {nlu_label}"
87+
)
88+
89+
turn_response = requests.post(
90+
turn_endpoint,
91+
json=label_payload,
92+
headers=headers,
93+
timeout=CELERY_TASK_OPTIONS["soft_time_limit"],
94+
)
95+
96+
turn_response.raise_for_status()
97+
98+
logger.info(
99+
f"Successfully labeled message {message_id} "
100+
f"as: {nlu_label}"
101+
)
102+
103+
except RequestException as e:
104+
105+
logger.warning(
106+
f"Turn API failed to label message {message_id}. "
107+
f"Retrying. Error: {e}"
108+
)
109+
raise
110+
except Exception as e:
111+
logger.error(
112+
f"Turn API unrecoverable error for message {message_id}: {e}"
113+
)
114+
return
115+
else:
116+
logger.info(
117+
f"NLU result was '{nlu_label}' (not 'Compliment' or 'Complaint'). "
118+
f"No Turn label applied for message {message_id}."
119+
)

nluclassifier/tests/test_tasks.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
import logging
2+
from unittest import mock
3+
from urllib.parse import urljoin
4+
5+
import requests.exceptions
6+
from django.conf import settings
7+
from django.test import TestCase
8+
9+
from nluclassifier.tasks import process_feedback_for_labeling
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
# fmt: off
15+
class NLUClassifierTaskTests(TestCase):
16+
"""
17+
Tests for the Celery task that handles NLU classification
18+
and Turn message labeling. The tests verify API calls and
19+
ensure RequestExceptions are re-raised for Celery's autoretry
20+
"""
21+
22+
def setUp(self):
23+
self.settings = settings
24+
25+
self.message_id = "12345"
26+
self.inbound_message = "I am not happy with your service."
27+
28+
self.expected_intent = "COMPLAINT"
29+
30+
def test_successful_classification_and_labeling(self):
31+
"""
32+
Tests the end-to-end flow where NLU classifies the message and
33+
Turn successfully applies the label
34+
"""
35+
nlu_success_mock = mock.Mock(status_code=200)
36+
nlu_success_mock.json.return_value = {
37+
"intent": self.expected_intent,
38+
"model_version": "2025-09-29-v1",
39+
"parent_label": "FEEDBACK",
40+
"probability": 0.3805,
41+
"review_status": "NEEDS_REVIEW",
42+
}
43+
nlu_success_mock.raise_for_status.return_value = None
44+
45+
turn_success_mock = mock.Mock(status_code=202)
46+
turn_success_mock.raise_for_status.return_value = None
47+
48+
with (
49+
mock.patch("requests.get") as mock_get,
50+
mock.patch("requests.post") as mock_post,
51+
):
52+
53+
mock_get.return_value = nlu_success_mock
54+
mock_post.return_value = turn_success_mock
55+
56+
result = process_feedback_for_labeling(
57+
self.message_id, self.inbound_message
58+
)
59+
60+
expected_nlu_endpoint = urljoin(
61+
self.settings.INTENT_CLASSIFIER_URL, "/nlu/feedback/"
62+
)
63+
64+
self.assertEqual(mock_get.call_count, 1)
65+
nlu_call_args = mock_get.call_args
66+
self.assertEqual(nlu_call_args[0][0], expected_nlu_endpoint)
67+
self.assertEqual(
68+
nlu_call_args[1]["params"]["question"], self.inbound_message
69+
)
70+
71+
self.assertEqual(mock_post.call_count, 1)
72+
turn_call_args = mock_post.call_args
73+
self.assertIn(
74+
f"messages/{self.message_id}/labels",
75+
turn_call_args[0][0]
76+
)
77+
self.assertEqual(
78+
turn_call_args[1]["headers"]["Authorization"],
79+
f"Bearer {self.settings.TURN_TOKEN}",
80+
)
81+
82+
turn_payload = turn_call_args[1]["json"]
83+
self.assertEqual(
84+
turn_payload["labels"],
85+
[self.expected_intent.lower()]
86+
)
87+
88+
self.assertTrue(result is None)
89+
90+
def test_nlu_api_failure_raises_exception_for_celery_retry(self):
91+
"""
92+
Tests that if the NLU API returns a failure,
93+
a RequestException is raised
94+
"""
95+
nlu_fail_mock = mock.Mock(status_code=404)
96+
error = requests.exceptions.HTTPError("404 Not Found")
97+
nlu_fail_mock.raise_for_status.side_effect = error
98+
99+
with (
100+
mock.patch("requests.get") as mock_get,
101+
mock.patch("requests.post") as mock_post,
102+
self.assertLogs(
103+
"nluclassifier.tasks", level="WARNING") as log_context,
104+
):
105+
106+
mock_get.return_value = nlu_fail_mock
107+
108+
with self.assertRaises(requests.exceptions.HTTPError):
109+
process_feedback_for_labeling(
110+
self.message_id, self.inbound_message)
111+
112+
self.assertEqual(mock_get.call_count, 1)
113+
114+
self.assertEqual(mock_post.call_count, 0)
115+
116+
self.assertTrue(
117+
any(
118+
"NLU service failed for message" in output
119+
for output in log_context.output
120+
)
121+
)
122+
123+
def test_turn_api_failure_raises_exception_for_retry(self):
124+
"""
125+
Tests that if NLU succeeds but the Turn API fails,
126+
a RequestException is rraised to trigger a retry.
127+
"""
128+
nlu_success_mock = mock.Mock(status_code=200)
129+
nlu_success_mock.json.return_value = {
130+
"intent": self.expected_intent,
131+
"model_version": "2025-09-29-v1",
132+
"parent_label": "FEEDBACK",
133+
"probability": 0.3805,
134+
"review_status": "NEEDS_REVIEW",
135+
}
136+
nlu_success_mock.raise_for_status.return_value = None
137+
138+
turn_fail_mock = mock.Mock(status_code=500)
139+
error = requests.exceptions.HTTPError("500 Server Error")
140+
turn_fail_mock.raise_for_status.side_effect = error
141+
142+
with (
143+
mock.patch("requests.get") as mock_get,
144+
mock.patch("requests.post") as mock_post,
145+
self.assertLogs(
146+
"nluclassifier.tasks", level="WARNING") as log_context,
147+
):
148+
mock_get.return_value = nlu_success_mock
149+
mock_post.return_value = turn_fail_mock
150+
151+
with self.assertRaises(requests.exceptions.HTTPError):
152+
process_feedback_for_labeling(
153+
self.message_id, self.inbound_message)
154+
155+
self.assertEqual(mock_get.call_count, 1)
156+
157+
self.assertEqual(mock_post.call_count, 1)
158+
159+
self.assertTrue(
160+
any("Turn API failed to label message" in output
161+
for output in log_context.output)
162+
)
163+
164+
def test_no_label_applied_on_unhandled_intent(self):
165+
"""
166+
Tests that if NLU returns an intent that is not
167+
'compliment' or 'complaint'.
168+
No Turn API call is made.
169+
"""
170+
nlu_success_mock = mock.Mock(status_code=200)
171+
172+
nlu_success_mock.json.return_value = {
173+
"intent": "None",
174+
"model_version": "2025-09-29-v1",
175+
"parent_label": "SENSITIVE_EXIT",
176+
"probability": 0.3357,
177+
"review_status": "NEEDS_REVIEW",
178+
}
179+
nlu_success_mock.raise_for_status.return_value = None
180+
181+
with (
182+
mock.patch("requests.get") as mock_get,
183+
mock.patch("requests.post") as mock_post,
184+
self.assertLogs(
185+
"nluclassifier.tasks", level="INFO") as log_context,
186+
):
187+
188+
mock_get.return_value = nlu_success_mock
189+
190+
result = process_feedback_for_labeling(
191+
self.message_id, self.inbound_message
192+
)
193+
194+
self.assertEqual(mock_get.call_count, 1)
195+
196+
self.assertEqual(mock_post.call_count, 0)
197+
198+
self.assertTrue(
199+
any("No Turn label applied" in output
200+
for output in log_context.output)
201+
)
202+
self.assertTrue(result is None)

0 commit comments

Comments
 (0)