From 0f185371beb6f4ad7d389ddaf9400250de19054f Mon Sep 17 00:00:00 2001 From: Franco Culaciati Date: Sun, 23 Feb 2025 17:31:08 -0300 Subject: [PATCH 1/4] Add PromptStr class to handle prompt interpolation --- .../jupyter_ai_magics/magics.py | 68 +++++++++++++++++-- 1 file changed, 63 insertions(+), 5 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index ebfdf7b4a..3d1696cf7 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -105,6 +105,64 @@ def _repr_mimebundle_(self, include=None, exclude=None): MULTIENV_REQUIRES = "Requires environment variables:" + +class PromptStr(str): + """ + A string subclass that processes its content to support a custom + placeholder delimiter. Custom placeholders are marked with "@{...}". + + When format() or format_map() is called, the instance is first processed: + - Custom placeholders (e.g. "@{var}") are converted into standard + placeholders ("{var}") for interpolation. + - All other literal curly braces are doubled (e.g. "{" becomes "{{") + so that they are preserved literally. + """ + + def __init__(self, text): + super().__init__() + self._template = self._process_template(text) + + @staticmethod + def _process_template(template: str) -> str: + """ + Process the input template so that: + - Any custom placeholder of the form "@{...}" is converted into + a normal placeholder "{...}". + - All other literal curly braces are doubled so that they remain + unchanged during formatting. + + Assumes that the custom placeholder does not contain nested braces. + """ + # Pattern to match custom placeholders: "@{...}" where ... has no braces. + pattern = r'@{([^{}]+)}' + tokens = [] + + def token_replacer(match): + inner = match.group(1) + # If by any chance the inner content contains braces, fail. + if "{" in inner or "}" in inner: + raise ValueError("Nested custom placeholders are not allowed") + tokens.append(inner) + # Replace with a temporary unique token. + return f'<<<{len(tokens)-1}>>>' + + # Replace custom placeholders with temporary tokens. + template_with_tokens = re.sub(pattern, token_replacer, template) + # Escape all remaining literal braces by doubling them. + escaped = template_with_tokens.replace("{", "{{").replace("}", "}}") + # Replace our temporary tokens with normal placeholders. + for i, token in enumerate(tokens): + escaped = escaped.replace(f'<<<{i}>>>', f'{{{token}}}') + return escaped + + + def format(self, *args, **kwargs): + return self._template.format(*args, **kwargs) + + def format_map(self, mapping): + return self._template.format_map(mapping) + + class FormatDict(dict): """Subclass of dict to be passed to str#format(). Suppresses KeyError and leaves replacement field unchanged if replacement field is not associated @@ -588,8 +646,8 @@ def run_ai_cell(self, args: CellArgs, prompt: str): prompt = provider.get_prompt_template(args.format).format(prompt=prompt) # interpolate user namespace into prompt - ip = self.shell - prompt = prompt.format_map(FormatDict(ip.user_ns)) + # ip = self.shell + # prompt = prompt.format_map(FormatDict(ip.user_ns)) context = self.transcript[-2 * self.max_history :] if self.max_history else [] if provider.is_chat_provider: @@ -676,10 +734,10 @@ def ai(self, line, cell=None): subcommands.""" ) - prompt = cell.strip() + prompt = PromptStr(cell.strip()) # interpolate user namespace into prompt ip = self.shell - prompt = prompt.format_map(FormatDict(ip.user_ns)) + prompt = prompt.format_map(ip.user_ns) - return self.run_ai_cell(args, prompt) + return self.run_ai_cell(args, prompt) \ No newline at end of file From 151fdd42233c356c0d519d35bbb6e58a1bfd7cc2 Mon Sep 17 00:00:00 2001 From: Franco Culaciati Date: Sun, 23 Feb 2025 17:34:20 -0300 Subject: [PATCH 2/4] Remove FormatDict and extra interpolation step --- .../jupyter_ai_magics/magics.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index 3d1696cf7..8eb41151f 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -163,19 +163,6 @@ def format_map(self, mapping): return self._template.format_map(mapping) -class FormatDict(dict): - """Subclass of dict to be passed to str#format(). Suppresses KeyError and - leaves replacement field unchanged if replacement field is not associated - with a value.""" - - def __missing__(self, key): - return key.join("{}") - - -class EnvironmentError(BaseException): - pass - - class CellMagicError(BaseException): pass @@ -645,10 +632,6 @@ def run_ai_cell(self, args: CellArgs, prompt: str): # Apply a prompt template. prompt = provider.get_prompt_template(args.format).format(prompt=prompt) - # interpolate user namespace into prompt - # ip = self.shell - # prompt = prompt.format_map(FormatDict(ip.user_ns)) - context = self.transcript[-2 * self.max_history :] if self.max_history else [] if provider.is_chat_provider: result = provider.generate([[*context, HumanMessage(content=prompt)]]) From 7164960100d4b3a9e25d300c20217f8c4249b7b8 Mon Sep 17 00:00:00 2001 From: Franco Culaciati Date: Sun, 23 Feb 2025 19:01:15 -0300 Subject: [PATCH 3/4] Improve error handling --- .../jupyter_ai_magics/magics.py | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index 8eb41151f..afb572e2e 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -116,10 +116,11 @@ class PromptStr(str): placeholders ("{var}") for interpolation. - All other literal curly braces are doubled (e.g. "{" becomes "{{") so that they are preserved literally. + + If any custom placeholder contains additional curly braces (i.e. nested + braces), a ValueError is raised. """ - def __init__(self, text): - super().__init__() self._template = self._process_template(text) @staticmethod @@ -132,6 +133,8 @@ def _process_template(template: str) -> str: unchanged during formatting. Assumes that the custom placeholder does not contain nested braces. + If nested or extra curly braces are found within a custom placeholder, + a ValueError is raised. """ # Pattern to match custom placeholders: "@{...}" where ... has no braces. pattern = r'@{([^{}]+)}' @@ -139,26 +142,22 @@ def _process_template(template: str) -> str: def token_replacer(match): inner = match.group(1) - # If by any chance the inner content contains braces, fail. - if "{" in inner or "}" in inner: - raise ValueError("Nested custom placeholders are not allowed") + assert ("{" not in inner) and ("}" not in inner) tokens.append(inner) - # Replace with a temporary unique token. return f'<<<{len(tokens)-1}>>>' - # Replace custom placeholders with temporary tokens. template_with_tokens = re.sub(pattern, token_replacer, template) - # Escape all remaining literal braces by doubling them. + if "@{" in template_with_tokens: + raise ValueError("Curly braces are not allowed inside custom placeholders.") + escaped = template_with_tokens.replace("{", "{{").replace("}", "}}") - # Replace our temporary tokens with normal placeholders. for i, token in enumerate(tokens): escaped = escaped.replace(f'<<<{i}>>>', f'{{{token}}}') return escaped - def format(self, *args, **kwargs): return self._template.format(*args, **kwargs) - + def format_map(self, mapping): return self._template.format_map(mapping) @@ -631,7 +630,7 @@ def run_ai_cell(self, args: CellArgs, prompt: str): # Apply a prompt template. prompt = provider.get_prompt_template(args.format).format(prompt=prompt) - + context = self.transcript[-2 * self.max_history :] if self.max_history else [] if provider.is_chat_provider: result = provider.generate([[*context, HumanMessage(content=prompt)]]) From 6253ff34ccfe6981ddd97154c539269b7b735e83 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 23 Feb 2025 22:12:27 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../jupyter_ai_magics/magics.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py index afb572e2e..8acb40c6d 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/magics.py @@ -105,21 +105,21 @@ def _repr_mimebundle_(self, include=None, exclude=None): MULTIENV_REQUIRES = "Requires environment variables:" - class PromptStr(str): """ A string subclass that processes its content to support a custom placeholder delimiter. Custom placeholders are marked with "@{...}". - + When format() or format_map() is called, the instance is first processed: - Custom placeholders (e.g. "@{var}") are converted into standard placeholders ("{var}") for interpolation. - All other literal curly braces are doubled (e.g. "{" becomes "{{") so that they are preserved literally. - + If any custom placeholder contains additional curly braces (i.e. nested braces), a ValueError is raised. """ + def __init__(self, text): self._template = self._process_template(text) @@ -131,33 +131,33 @@ def _process_template(template: str) -> str: a normal placeholder "{...}". - All other literal curly braces are doubled so that they remain unchanged during formatting. - + Assumes that the custom placeholder does not contain nested braces. If nested or extra curly braces are found within a custom placeholder, a ValueError is raised. """ # Pattern to match custom placeholders: "@{...}" where ... has no braces. - pattern = r'@{([^{}]+)}' + pattern = r"@{([^{}]+)}" tokens = [] - + def token_replacer(match): inner = match.group(1) assert ("{" not in inner) and ("}" not in inner) tokens.append(inner) - return f'<<<{len(tokens)-1}>>>' - + return f"<<<{len(tokens)-1}>>>" + template_with_tokens = re.sub(pattern, token_replacer, template) if "@{" in template_with_tokens: raise ValueError("Curly braces are not allowed inside custom placeholders.") - + escaped = template_with_tokens.replace("{", "{{").replace("}", "}}") for i, token in enumerate(tokens): - escaped = escaped.replace(f'<<<{i}>>>', f'{{{token}}}') + escaped = escaped.replace(f"<<<{i}>>>", f"{{{token}}}") return escaped - + def format(self, *args, **kwargs): return self._template.format(*args, **kwargs) - + def format_map(self, mapping): return self._template.format_map(mapping) @@ -630,7 +630,7 @@ def run_ai_cell(self, args: CellArgs, prompt: str): # Apply a prompt template. prompt = provider.get_prompt_template(args.format).format(prompt=prompt) - + context = self.transcript[-2 * self.max_history :] if self.max_history else [] if provider.is_chat_provider: result = provider.generate([[*context, HumanMessage(content=prompt)]]) @@ -722,4 +722,4 @@ def ai(self, line, cell=None): ip = self.shell prompt = prompt.format_map(ip.user_ns) - return self.run_ai_cell(args, prompt) \ No newline at end of file + return self.run_ai_cell(args, prompt)