Skip to content
This repository was archived by the owner on Apr 28, 2021. It is now read-only.

Commit d16e75f

Browse files
author
pheel
authored
Merge pull request #26 from Amirali-Shirkh/feature/url-rewriting-for-images
feat: support for image url text replacements in BotfrontTemplatedNaturalLanguageGenerator and GraphQLNaturalLanguageGenerator
2 parents e5dac90 + ef3792c commit d16e75f

File tree

4 files changed

+41
-3
lines changed

4 files changed

+41
-3
lines changed

rasa_addons/core/nlg/bftemplate.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import copy
22
import logging
33
from collections import defaultdict
4-
4+
from rasa_addons.core.nlg.nlg_helper import rewrite_url
55
from rasa.core.trackers import DialogueStateTracker
66
from typing import Text, Any, Dict, Optional, List
77

@@ -14,6 +14,8 @@
1414
class BotfrontTemplatedNaturalLanguageGenerator(NaturalLanguageGenerator):
1515
def __init__(self, **kwargs) -> None:
1616
domain = kwargs.get("domain")
17+
templated_endpoint = kwargs.get("endpoint_config")
18+
self.url_substitution_pattern = templated_endpoint.kwargs.get('url_substitutions') or []
1719
self.templates = domain.templates if domain else []
1820

1921
def _templates_for_utter_action(self, utter_action, output_channel, **kwargs):
@@ -85,11 +87,13 @@ async def generate(
8587
fallback_language=fallback_language,
8688
)
8789
if "language" in message: del message["language"]
90+
rewrite_url(message, self.url_substitution_pattern)
8891
metadata = message.pop("metadata", {}) or {}
8992
for key in metadata: message[key] = metadata[key]
9093

9194
return message
9295

96+
9397
def generate_from_slots(
9498
self,
9599
template_name: Text,

rasa_addons/core/nlg/constants.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from enum import Enum
2+
3+
class NlgEnum(Enum):
4+
"""A class to represent constants for enum values for nlg classes."""
5+
6+
IMAGE = "image"
7+
IMAGE_URL = "image_url"
8+
ELEMENTS = "elements"
9+
PATTERN = "pattern"
10+
REPLACEMENT = "replacement"

rasa_addons/core/nlg/graphql.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
from typing import Text, Any, Dict, Optional, List
3-
3+
from rasa_addons.core.nlg.nlg_helper import rewrite_url
44
from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
55
from rasa.core.nlg.generator import NaturalLanguageGenerator
66
from rasa.core.trackers import DialogueStateTracker, EventVerbosity
@@ -100,13 +100,13 @@ def nlg_request_format(
100100
"channel": {"name": output_channel},
101101
}
102102

103-
104103
class GraphQLNaturalLanguageGenerator(NaturalLanguageGenerator):
105104
"""Like Rasa's CallbackNLG, but queries Botfront's GraphQL endpoint"""
106105

107106
def __init__(self, **kwargs) -> None:
108107
endpoint_config = kwargs.get("endpoint_config")
109108
self.nlg_endpoint = endpoint_config
109+
self.url_substitution_pattern = endpoint_config.kwargs.get('url_substitutions') or []
110110

111111
async def generate(
112112
self,
@@ -152,6 +152,7 @@ async def generate(
152152
", ".join([e.get("message") for e in response.get("errors")])
153153
)
154154
response = response.get("data", {}).get("getResponse", {})
155+
rewrite_url(response, self.url_substitution_pattern)
155156
if "customText" in response:
156157
response["text"] = response.pop("customText")
157158
if "customImage" in response:

rasa_addons/core/nlg/nlg_helper.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import re
2+
from rasa_addons.core.nlg.constants import NlgEnum
3+
4+
def rewrite_url(message: dict, url_substitution_pattern: list):
5+
"""Rewrite image url with the pattern found in endpoint."""
6+
7+
if url_substitution_pattern:
8+
if NlgEnum.IMAGE.value in message.keys():
9+
substitute(message, NlgEnum.IMAGE.value, url_substitution_pattern)
10+
elif NlgEnum.ELEMENTS.value in message.keys():
11+
for element in message[NlgEnum.ELEMENTS.value]:
12+
substitute(element, NlgEnum.IMAGE_URL.value, url_substitution_pattern)
13+
14+
def substitute(message: dict, key: str, url_substitution_pattern: list):
15+
"""Substitute rewritten url."""
16+
17+
url = message[key]
18+
for item in url_substitution_pattern:
19+
substitute = re.sub(item.get(NlgEnum.PATTERN.value), item.get(NlgEnum.REPLACEMENT.value), message[key])
20+
if substitute != url:
21+
message[key] = substitute
22+
return
23+
return

0 commit comments

Comments
 (0)