Skip to content

Commit 1e6b2f3

Browse files
Fix KERNEL_POD_NAME substitution to avoid SSTI (#1412)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 152c20f commit 1e6b2f3

File tree

4 files changed

+287
-9
lines changed

4 files changed

+287
-9
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ clean-env: ## Remove conda env
6767
lint: ## Check code style
6868
@pip install -q -e ".[lint]"
6969
@pip install -q pipx
70-
ruff .
70+
ruff check .
7171
black --check --diff --color .
7272
mdformat --check *.md
7373
pipx run 'validate-pyproject[all]' pyproject.toml

docs/source/users/kernel-envs.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,12 @@ There are several supported `KERNEL_` variables that the Enterprise Gateway serv
7676
it is the user's responsibility that KERNEL_POD_NAME is unique relative to
7777
any pods in the target namespace. In addition, the pod must NOT exist -
7878
unlike the case if KERNEL_NAMESPACE is provided. The KERNEL_POD_NAME can
79-
also be provided as a jinja2 template string
79+
also be provided as a jinja2 template formatted string
8080
(e.g "{{ kernel_prefix }}-{{ kernel_id | replace('-', '') }}")
81-
which will be evaluated against existing list of environment variables.
81+
which will be processed for safe substitution against existing list
82+
of environment variables. In case of invalid template (e.g. missing variables)
83+
it will fall back to original way to calculate the pod name using
84+
KERNEL_USERNAME - KERNEL_ID.
8285
8386
KERNEL_REMOTE_HOST=<remote host name>
8487
DistributedProcessProxy only. When specified, this value will override the

enterprise_gateway/services/processproxies/k8s.py

Lines changed: 50 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from typing import Any
1212

1313
import urllib3
14-
from jinja2 import BaseLoader, Environment
1514
from kubernetes import client, config
1615

1716
from ..kernels.remotemanager import RemoteKernelManager
@@ -216,6 +215,42 @@ def terminate_container_resources(self) -> bool | None:
216215

217216
return result
218217

218+
def _safe_template_substitute(self, template_str: str, variables: dict) -> str | None:
219+
"""
220+
Safely substitute variables in Jinja2-style template syntax.
221+
Only supports simple variable substitution: {{ variable_name }}
222+
Logs missing variables and returns None if any are missing.
223+
"""
224+
# Pattern to match {{ variable_name }} with optional whitespace
225+
# Explicitly exclude variables starting with underscore to prevent magic method attacks
226+
pattern = r'\{\{\s*([a-zA-Z][a-zA-Z0-9_]*)\s*\}\}'
227+
missing_vars = []
228+
229+
def replace_var(match):
230+
var_name = match.group(1)
231+
if var_name in variables:
232+
return str(variables[var_name])
233+
else:
234+
missing_vars.append(var_name)
235+
return match.group(0) # Keep original placeholder
236+
237+
result = re.sub(pattern, replace_var, template_str)
238+
239+
# Check if there are any remaining {{ }} patterns that didn't match our simple pattern
240+
# This catches malicious templates like {{ foo.__class__ }} or {{ 1+1 }}
241+
if '{{' in result and '}}' in result:
242+
self.log.warning(
243+
"Invalid template syntax detected in KERNEL_POD_NAME: contains unsupported expressions"
244+
)
245+
return None
246+
247+
# Log missing variables and return None if any are missing
248+
if missing_vars:
249+
self.log.warning(f"Template variables not found in KERNEL_POD_NAME: {missing_vars}")
250+
return None # Signal caller to use default
251+
252+
return result
253+
219254
def _determine_kernel_pod_name(self, **kwargs: dict[str, Any] | None) -> str:
220255
pod_name = kwargs["env"].get("KERNEL_POD_NAME")
221256

@@ -224,16 +259,25 @@ def _determine_kernel_pod_name(self, **kwargs: dict[str, Any] | None) -> str:
224259
else:
225260
self.log.debug(f"Processing KERNEL_POD_NAME based on env var => {pod_name}")
226261
if "{{" in pod_name and "}}" in pod_name:
227-
self.log.debug("Processing KERNEL_POD_NAME as jinja template")
228-
# Create Jinja2 environment
262+
self.log.debug("Processing KERNEL_POD_NAME template variables")
229263
keywords = {}
230264
for name, value in kwargs["env"].items():
231265
if name.startswith("KERNEL_"):
232266
keywords[name.lower()] = value
233267
keywords["kernel_id"] = self.kernel_id
234-
self.log.debug("Processing pod_name jinja template")
235-
env = Environment(loader=BaseLoader(), autoescape=True)
236-
pod_name = env.from_string(pod_name).render(**keywords)
268+
269+
# Safe template substitution with fallback
270+
substituted = self._safe_template_substitute(pod_name, keywords)
271+
if substituted is None:
272+
# Fall back to default if template variables are missing
273+
self.log.warning(
274+
"Falling back to default pod name due to missing template variables"
275+
)
276+
pod_name = (
277+
KernelSessionManager.get_kernel_username(**kwargs) + "-" + self.kernel_id
278+
)
279+
else:
280+
pod_name = substituted
237281

238282
# Rewrite pod_name to be compatible with DNS name convention
239283
# And put back into env since kernel needs this
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
# Copyright (c) Jupyter Development Team.
2+
# Distributed under the terms of the Modified BSD License.
3+
"""Tests for Kubernetes process proxy security fixes."""
4+
5+
import unittest
6+
from unittest.mock import Mock, patch
7+
8+
# Mock Kubernetes configuration before importing the module
9+
with patch('kubernetes.config.load_incluster_config'), patch('kubernetes.config.load_kube_config'):
10+
from enterprise_gateway.services.processproxies.k8s import KubernetesProcessProxy
11+
12+
13+
class TestKubernetesProcessProxy(unittest.TestCase):
14+
"""Test secure template substitution in Kubernetes process proxy."""
15+
16+
def setUp(self):
17+
"""Set up test fixtures."""
18+
self.mock_kernel_manager = Mock()
19+
self.mock_kernel_manager.get_kernel_username.return_value = "testuser"
20+
self.mock_kernel_manager.port_range = "0..0" # Mock port range
21+
22+
# Mock proxy config
23+
self.proxy_config = {"kernel_id": "test-kernel-id", "kernel_name": "python3"}
24+
25+
# Mock KernelSessionManager methods
26+
with patch(
27+
'enterprise_gateway.services.processproxies.k8s.KernelSessionManager'
28+
) as mock_session_manager:
29+
mock_session_manager.get_kernel_username.return_value = "testuser"
30+
self.proxy = KubernetesProcessProxy(self.mock_kernel_manager, self.proxy_config)
31+
self.proxy.kernel_id = "test-kernel-id"
32+
33+
def test_valid_template_substitution(self):
34+
"""Test valid template variable substitution."""
35+
test_cases = [
36+
# Basic variable substitution
37+
("{{ kernel_id }}", {"kernel_id": "test-123"}, "test-123"),
38+
# Multiple variables
39+
(
40+
"{{ kernel_namespace }}-{{ kernel_id }}",
41+
{"kernel_namespace": "default", "kernel_id": "test-123"},
42+
"default-test-123",
43+
),
44+
# Variables with underscores
45+
("{{ kernel_image_pull_policy }}", {"kernel_image_pull_policy": "Always"}, "Always"),
46+
# Whitespace handling
47+
("{{ kernel_id }}", {"kernel_id": "test-123"}, "test-123"),
48+
]
49+
50+
for template, variables, expected in test_cases:
51+
with self.subTest(template=template):
52+
result = self.proxy._safe_template_substitute(template, variables)
53+
self.assertEqual(result, expected)
54+
55+
def test_missing_variables_fallback(self):
56+
# Test the full pod name determination process
57+
kwargs = {
58+
"env": {
59+
"KERNEL_POD_NAME": "{{ missing_var }}",
60+
"KERNEL_NAMESPACE": "production",
61+
}
62+
}
63+
64+
with patch.object(self.proxy, 'log'), patch(
65+
'enterprise_gateway.services.processproxies.k8s.KernelSessionManager'
66+
) as mock_session_manager:
67+
mock_session_manager.get_kernel_username.return_value = "testuser"
68+
result = self.proxy._determine_kernel_pod_name(**kwargs)
69+
# Should fall back to default naming: kernel_username + "-" + kernel_id
70+
self.assertEqual(result, "testuser-test-kernel-id")
71+
72+
def test_malicious_template_injection_prevention(self):
73+
"""Test prevention of malicious template injection attacks."""
74+
malicious_templates = [
75+
# Python code execution attempts
76+
"{{ ''.__class__.__mro__[1].__subclasses__()[104].__init__.__globals__['sys'].exit() }}",
77+
"{{ __import__('os').system('rm -rf /') }}",
78+
"{{ exec('print(\"pwned\")') }}",
79+
"{{ eval('1+1') }}",
80+
# Attribute access attempts
81+
"{{ kernel_id.__class__ }}",
82+
"{{ kernel_id.__dict__ }}",
83+
"{{ kernel_id.__globals__ }}",
84+
# Function calls
85+
"{{ range(10) }}",
86+
"{{ len(kernel_id) }}",
87+
"{{ str.upper(kernel_id) }}",
88+
# Jinja2 filters and expressions
89+
"{{ kernel_id|upper }}",
90+
"{{ kernel_id + '_suffix' }}",
91+
"{{ 1 + 1 }}",
92+
# Complex expressions
93+
"{{ kernel_id if kernel_id else 'default' }}",
94+
"{{ kernel_id[:5] }}",
95+
]
96+
97+
variables = {"kernel_id": "test-123"}
98+
99+
for malicious_template in malicious_templates:
100+
with self.subTest(template=malicious_template), patch.object(
101+
self.proxy, 'log'
102+
) as mock_log:
103+
result = self.proxy._safe_template_substitute(malicious_template, variables)
104+
# All malicious templates should be treated as invalid and return None
105+
self.assertIsNone(result)
106+
mock_log.warning.assert_called_once()
107+
# Should warn about unsupported expressions
108+
self.assertIn("Invalid template syntax", mock_log.warning.call_args[0][0])
109+
110+
def test_pod_name_determination_with_templates(self):
111+
"""Test complete pod name determination with template processing."""
112+
kwargs = {
113+
"env": {
114+
"KERNEL_POD_NAME": "{{ kernel_namespace }}-{{ kernel_id }}",
115+
"KERNEL_NAMESPACE": "production",
116+
"KERNEL_IMAGE": "python:3.9",
117+
}
118+
}
119+
120+
with patch.object(self.proxy, 'log'):
121+
result = self.proxy._determine_kernel_pod_name(**kwargs)
122+
# Should get processed and DNS-normalized
123+
self.assertEqual(result, "production-test-kernel-id")
124+
125+
def test_pod_name_determination_with_malicious_template(self):
126+
"""Test pod name determination with malicious template falls back to default."""
127+
kwargs = {
128+
"env": {
129+
"KERNEL_POD_NAME": "{{ __import__('os').system('evil') }}",
130+
"KERNEL_NAMESPACE": "production",
131+
}
132+
}
133+
134+
with patch.object(self.proxy, 'log'), patch(
135+
'enterprise_gateway.services.processproxies.k8s.KernelSessionManager'
136+
) as mock_session_manager:
137+
mock_session_manager.get_kernel_username.return_value = "testuser"
138+
result = self.proxy._determine_kernel_pod_name(**kwargs)
139+
# Should fall back to default naming
140+
self.assertEqual(result, "testuser-test-kernel-id")
141+
142+
def test_pod_name_determination_with_missing_variables(self):
143+
"""Test pod name determination with missing variables falls back to default."""
144+
kwargs = {
145+
"env": {
146+
"KERNEL_POD_NAME": "{{ missing_var }}-{{ kernel_id }}",
147+
"KERNEL_NAMESPACE": "production",
148+
}
149+
}
150+
151+
with patch.object(self.proxy, 'log'), patch(
152+
'enterprise_gateway.services.processproxies.k8s.KernelSessionManager'
153+
) as mock_session_manager:
154+
mock_session_manager.get_kernel_username.return_value = "testuser"
155+
result = self.proxy._determine_kernel_pod_name(**kwargs)
156+
# Should fall back to default naming
157+
self.assertEqual(result, "testuser-test-kernel-id")
158+
159+
def test_pod_name_without_template(self):
160+
"""Test pod name determination without template syntax."""
161+
kwargs = {"env": {"KERNEL_POD_NAME": "static-pod-name", "KERNEL_NAMESPACE": "production"}}
162+
163+
with patch.object(self.proxy, 'log'):
164+
result = self.proxy._determine_kernel_pod_name(**kwargs)
165+
# Should use as-is and DNS-normalize
166+
self.assertEqual(result, "static-pod-name")
167+
168+
def test_pod_name_dns_normalization(self):
169+
"""Test DNS name normalization of pod names."""
170+
kwargs = {
171+
"env": {
172+
"KERNEL_POD_NAME": "{{ kernel_namespace }}_{{ kernel_id }}",
173+
"KERNEL_NAMESPACE": "Test-Namespace",
174+
"KERNEL_IMAGE": "python:3.9",
175+
}
176+
}
177+
178+
with patch.object(self.proxy, 'log'):
179+
result = self.proxy._determine_kernel_pod_name(**kwargs)
180+
# Should be DNS-normalized (lowercase, dashes only)
181+
self.assertEqual(result, "test-namespace-test-kernel-id")
182+
183+
def test_regex_pattern_validation(self):
184+
"""Test that only valid variable names are matched by regex."""
185+
valid_vars = [
186+
"kernel_id",
187+
"kernel_namespace",
188+
"kernel_image_pull_policy",
189+
"a",
190+
"var123",
191+
"KERNEL_ID",
192+
]
193+
194+
# Variables that should be blocked by the regex pattern
195+
invalid_vars = [
196+
"123invalid", # starts with number
197+
"invalid-var", # contains dash
198+
"invalid.var", # contains dot
199+
"invalid var", # contains space
200+
"invalid@var", # contains special char
201+
"_private_var", # starts with underscore (security risk)
202+
"__class__", # magic method (security risk)
203+
"__dict__", # magic method (security risk)
204+
"__globals__", # magic method (security risk)
205+
]
206+
207+
variables = {var: "value" for var in valid_vars}
208+
# Also add underscore variables to test they're not substituted even if present
209+
variables.update(
210+
{"_private_var": "private", "__class__": "dangerous", "__dict__": "dangerous"}
211+
)
212+
213+
# Valid variables should be substituted
214+
for var in valid_vars:
215+
template = f"{{{{ {var} }}}}"
216+
result = self.proxy._safe_template_substitute(template, variables)
217+
self.assertEqual(result, "value", f"Valid variable {var} should be substituted")
218+
219+
# Invalid variables should be treated as having invalid syntax
220+
for var in invalid_vars:
221+
template = f"{{{{ {var} }}}}"
222+
with patch.object(self.proxy, 'log') as mock_log:
223+
result = self.proxy._safe_template_substitute(template, variables)
224+
self.assertIsNone(result, f"Invalid variable {var} should be rejected")
225+
mock_log.warning.assert_called_once()
226+
# Should warn about unsupported expressions since invalid var names don't match regex
227+
self.assertIn("Invalid template syntax", mock_log.warning.call_args[0][0])
228+
229+
230+
if __name__ == '__main__':
231+
unittest.main()

0 commit comments

Comments
 (0)