class PromptStr(str):
    """
    A string whose interpolation placeholders are written as "@{...}".

    The instance itself compares and prints as the original, unprocessed
    text. On construction a second, format-ready copy of the text is built
    and stored in ``_template``:

    - Custom placeholders (e.g. "@{var}") become standard replacement
      fields ("{var}").
    - Every other literal curly brace is doubled ("{" becomes "{{") so that
      it survives formatting unchanged.

    ``format()`` and ``format_map()`` operate on that processed copy and
    return plain ``str`` objects.

    A "@{" that does not open a well-formed, brace-free, non-empty
    placeholder (nested braces, "@{}", or a missing "}") raises ValueError
    at construction time.
    """

    def __init__(self, text):
        # str.__new__ has already stored the raw text as the string value;
        # here we only precompute the format-ready template.
        self._template = self._process_template(text)

    @staticmethod
    def _process_template(template: str) -> str:
        """
        Convert a "@{...}" template into a ``str.format`` template.

        The conversion is done in a single left-to-right pass, so no
        intermediate marker text is ever inserted into the template and
        nothing the user typed can be mistaken for a placeholder.

        Args:
            template: Raw text using "@{name}" for placeholders.

        Returns:
            The text with "@{name}" rewritten to "{name}" and all other
            curly braces doubled.

        Raises:
            ValueError: If "@{" appears without forming a valid
                placeholder (braces inside the placeholder, an empty
                placeholder, or an unterminated one).
        """
        # Alternative 1: a complete custom placeholder; group 1 is its
        # brace-free content. Alternative 2: any single literal brace.
        # Alternative 1 is tried first at each position, so the braces of
        # a valid placeholder are never seen by alternative 2.
        pattern = r"@\{([^{}]+)\}|[{}]"

        def convert(match):
            field = match.group(1)
            if field is not None:
                return "{" + field + "}"

            brace = match.group(0)
            position = match.start()
            # A lone "{" directly after "@" is a placeholder that failed to
            # match alternative 1, i.e. it is malformed.
            if brace == "{" and position > 0 and template[position - 1] == "@":
                raise ValueError(
                    "Curly braces are not allowed inside custom placeholders."
                )
            # Escape the literal brace for str.format.
            return brace * 2

        return re.sub(pattern, convert, template)

    def format(self, *args, **kwargs):
        """Interpolate positional/keyword values into the "@{...}" fields.

        Delegates to ``str.format`` on the processed template; any lookup
        error raised there propagates to the caller.
        """
        return self._template.format(*args, **kwargs)

    def format_map(self, mapping):
        """Interpolate values from ``mapping`` into the "@{...}" fields.

        Delegates to ``str.format_map`` on the processed template; any
        lookup error raised there propagates to the caller.
        """
        return self._template.format_map(mapping)