diff --git a/guardrails/prompt/messages.py b/guardrails/prompt/messages.py
index 8740cc46c..09d79a806 100644
--- a/guardrails/prompt/messages.py
+++ b/guardrails/prompt/messages.py
@@ -68,7 +68,7 @@ def format(
                 msg_str = message["content"]._source
             # Only use the keyword arguments that are present in the message.
             vars = get_template_variables(msg_str)
-            filtered_kwargs = {k: v for k, v in kwargs.items() if k in vars}
+            filtered_kwargs = {k: kwargs[k] for k in vars if k in kwargs}
 
             # Return another instance of the class with the formatted message.
             formatted_message = Template(msg_str).safe_substitute(**filtered_kwargs)
diff --git a/guardrails/utils/templating_utils.py b/guardrails/utils/templating_utils.py
index 2284c6f00..a96ee80af 100644
--- a/guardrails/utils/templating_utils.py
+++ b/guardrails/utils/templating_utils.py
@@ -1,12 +1,21 @@
-import collections
 from string import Template
 from typing import List
 
 
 def get_template_variables(template: str) -> List[str]:
+    """Return the placeholder names used in ``template``, in first-seen order.
+
+    Both ``$name`` and ``${name}`` forms are reported, once each; escaped
+    ``$$`` sequences and invalid placeholders are skipped.
+    """
+    # string.Template.get_identifiers() was added in Python 3.11; prefer it.
     if hasattr(Template, "get_identifiers"):
         return Template(template).get_identifiers()  # type: ignore
-    else:
-        d = collections.defaultdict(str)
-        Template(template).safe_substitute(d)
-        return list(d.keys())
+    # Older Pythons: scan with Template's own compiled pattern so the result
+    # matches exactly what safe_substitute() would replace.
+    names: List[str] = []
+    for match in Template.pattern.finditer(template):
+        name = match.group("named") or match.group("braced")
+        if name is not None and name not in names:
+            names.append(name)
+    return names