Skip to content

Commit c399f93

Browse files
authored
[develop] Upgrade connexion to ~=2.15.1, upgrade Werkzeug to ~=3.1 to address CVE-2024-34069 (#6932)
* Upgrade Connexion to ~=2.15.1 (from ~=2.13.0). * Upgrade Flask to ~=3.0 (from >=2.2.5,<2.3). * Upgrade Werkzeug to ~=3.1 (from ~=2.0) to address [CVE-2024-34069](https://nvd.nist.gov/vuln/detail/cve-2024-34069). * Upgrade serverless_wsgi.py to the latest version. * Changes to encoder.py, flask_app.py and etc. to adapt the version bump.
1 parent 9736719 commit c399f93

File tree

7 files changed

+87
-39
lines changed

7 files changed

+87
-39
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ CHANGELOG
66

77
**CHANGES**
88
- Add validator that warns against the downsides of disabling in-place updates on compute and login nodes through DevSettings.
9+
- Upgrade Connexion to ~=2.15.1 (from ~=2.13.0).
10+
- Upgrade Flask to ~=3.0 (from >=2.2.5,<2.3).
11+
- Upgrade Werkzeug to ~=3.1 (from ~=2.0) to address [CVE-2024-34069](https://nvd.nist.gov/vuln/detail/cve-2024-34069).
912

1013
**BUG FIXES**
1114
- Reduce EFA installation time for Ubuntu by ~20 minutes by only holding kernel packages for the installed kernel.

cli/requirements.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ aws-cdk.core~=1.164
1616
aws_cdk.aws-cloudwatch~=1.164
1717
aws_cdk.aws-lambda~=1.164
1818
boto3>=1.39.4
19-
connexion~=2.13.0
20-
flask>=2.2.5,<2.3
2119
jinja2~=3.0
2220
jmespath~=0.10
2321
jsii==1.85.0
2422
marshmallow~=3.10
2523
PyYAML>=5.3.1,!=5.4
2624
tabulate>=0.8.8,<=0.8.10
27-
werkzeug~=2.0
25+
connexion~=2.15.1
26+
werkzeug~=3.1
27+
flask~=3.0
2828
packaging~=25.0

cli/setup.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,12 @@ def readme():
4747
"aws-cdk.aws-ssm~=" + CDK_VERSION,
4848
"aws-cdk.aws-sqs~=" + CDK_VERSION,
4949
"aws-cdk.aws-cloudformation~=" + CDK_VERSION,
50-
"werkzeug~=2.0",
51-
"connexion~=2.13.0",
52-
"flask>=2.2.5,<2.3",
50+
"connexion~=2.15.1",
5351
"jmespath~=0.10",
5452
"jsii==1.85.0",
53+
"werkzeug~=3.1",
54+
"flask~=3.0",
55+
"packaging~=25.0",
5556
]
5657

5758
LAMBDA_REQUIRES = [

cli/src/pcluster/api/awslambda/serverless_wsgi.py

Lines changed: 46 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
import json
1818
import os
1919
import sys
20+
from urllib.parse import unquote, unquote_plus, urlencode
2021

21-
from werkzeug.datastructures import Headers, MultiDict, iter_multi_items
22+
from werkzeug.datastructures import Headers, iter_multi_items
2223
from werkzeug.http import HTTP_STATUS_CODES
23-
from werkzeug.urls import url_encode, url_unquote, url_unquote_plus
2424
from werkzeug.wrappers import Response
2525

2626
# List of MIME types that should not be base64 encoded. MIME types within `text/*`
@@ -95,8 +95,8 @@ def encode_query_string(event):
9595
if not params:
9696
params = ""
9797
if is_alb_event(event):
98-
params = MultiDict((url_unquote_plus(k), url_unquote_plus(v)) for k, v in iter_multi_items(params))
99-
return url_encode(params)
98+
params = [(unquote_plus(k), unquote_plus(v)) for k, v in iter_multi_items(params)]
99+
return urlencode(params, doseq=True)
100100

101101

102102
def get_script_name(headers, request_context):
@@ -108,7 +108,7 @@ def get_script_name(headers, request_context):
108108
"1",
109109
]
110110

111-
if headers.get("Host", "").endswith(".amazonaws.com") and not strip_stage_path:
111+
if "amazonaws.com" in headers.get("Host", "") and not strip_stage_path:
112112
script_name = "/{}".format(request_context.get("stage", ""))
113113
else:
114114
script_name = ""
@@ -138,7 +138,7 @@ def setup_environ_items(environ, headers):
138138
def generate_response(response, event):
139139
returndict = {"statusCode": response.status_code}
140140

141-
if "multiValueHeaders" in event:
141+
if "multiValueHeaders" in event and event["multiValueHeaders"]:
142142
returndict["multiValueHeaders"] = group_headers(response.headers)
143143
else:
144144
returndict["headers"] = split_headers(response.headers)
@@ -164,12 +164,27 @@ def generate_response(response, event):
164164
return returndict
165165

166166

167+
def strip_express_gateway_query_params(path):
168+
"""Contrary to regular AWS lambda HTTP events, Express Gateway
169+
(https://github.com/ExpressGateway/express-gateway-plugin-lambda)
170+
adds query parameters to the path, which we need to strip.
171+
"""
172+
if "?" in path:
173+
path = path.split("?")[0]
174+
return path
175+
176+
167177
def handle_request(app, event, context):
168178
if event.get("source") in ["aws.events", "serverless-plugin-warmup"]:
169179
print("Lambda warming event received, skipping handler")
170180
return {}
171181

172-
if event.get("version") is None and event.get("isBase64Encoded") is None and not is_alb_event(event):
182+
if (
183+
event.get("version") is None
184+
and event.get("isBase64Encoded") is None
185+
and event.get("requestPath") is not None
186+
and not is_alb_event(event)
187+
):
173188
return handle_lambda_integration(app, event, context)
174189

175190
if event.get("version") == "2.0":
@@ -179,7 +194,7 @@ def handle_request(app, event, context):
179194

180195

181196
def handle_payload_v1(app, event, context):
182-
if "multiValueHeaders" in event:
197+
if "multiValueHeaders" in event and event["multiValueHeaders"]:
183198
headers = Headers(event["multiValueHeaders"])
184199
else:
185200
headers = Headers(event["headers"])
@@ -189,35 +204,35 @@ def handle_payload_v1(app, event, context):
189204
# If a user is using a custom domain on API Gateway, they may have a base
190205
# path in their URL. This allows us to strip it out via an optional
191206
# environment variable.
192-
path_info = event["path"]
207+
path_info = strip_express_gateway_query_params(event["path"])
193208
base_path = os.environ.get("API_GATEWAY_BASE_PATH")
194209
if base_path:
195210
script_name = "/" + base_path
196211

197212
if path_info.startswith(script_name):
198213
path_info = path_info[len(script_name) :] # noqa: E203
199214

200-
body = event["body"] or ""
215+
body = event.get("body") or ""
201216
body = get_body_bytes(event, body)
202217

203218
environ = {
204219
"CONTENT_LENGTH": str(len(body)),
205220
"CONTENT_TYPE": headers.get("Content-Type", ""),
206-
"PATH_INFO": url_unquote(path_info),
221+
"PATH_INFO": unquote(path_info),
207222
"QUERY_STRING": encode_query_string(event),
208223
"REMOTE_ADDR": event.get("requestContext", {}).get("identity", {}).get("sourceIp", ""),
209-
"REMOTE_USER": event.get("requestContext", {}).get("authorizer", {}).get("principalId", ""),
224+
"REMOTE_USER": (event.get("requestContext", {}).get("authorizer") or {}).get("principalId", ""),
210225
"REQUEST_METHOD": event.get("httpMethod", {}),
211226
"SCRIPT_NAME": script_name,
212227
"SERVER_NAME": headers.get("Host", "lambda"),
213-
"SERVER_PORT": headers.get("X-Forwarded-Port", "80"),
228+
"SERVER_PORT": headers.get("X-Forwarded-Port", "443"),
214229
"SERVER_PROTOCOL": "HTTP/1.1",
215230
"wsgi.errors": sys.stderr,
216231
"wsgi.input": io.BytesIO(body),
217232
"wsgi.multiprocess": False,
218233
"wsgi.multithread": False,
219234
"wsgi.run_once": False,
220-
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "http"),
235+
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "https"),
221236
"wsgi.version": (1, 0),
222237
"serverless.authorizer": event.get("requestContext", {}).get("authorizer"),
223238
"serverless.event": event,
@@ -237,31 +252,37 @@ def handle_payload_v2(app, event, context):
237252

238253
script_name = get_script_name(headers, event.get("requestContext", {}))
239254

240-
path_info = event["rawPath"]
255+
path_info = strip_express_gateway_query_params(event["rawPath"])
256+
base_path = os.environ.get("API_GATEWAY_BASE_PATH")
257+
if base_path:
258+
script_name = "/" + base_path
259+
260+
if path_info.startswith(script_name):
261+
path_info = path_info[len(script_name) :] # noqa: E203
241262

242263
body = event.get("body", "")
243264
body = get_body_bytes(event, body)
244265

245266
headers["Cookie"] = "; ".join(event.get("cookies", []))
246267

247268
environ = {
248-
"CONTENT_LENGTH": str(len(body)),
269+
"CONTENT_LENGTH": str(len(body or "")),
249270
"CONTENT_TYPE": headers.get("Content-Type", ""),
250-
"PATH_INFO": url_unquote(path_info),
271+
"PATH_INFO": unquote(path_info),
251272
"QUERY_STRING": event.get("rawQueryString", ""),
252273
"REMOTE_ADDR": event.get("requestContext", {}).get("http", {}).get("sourceIp", ""),
253274
"REMOTE_USER": event.get("requestContext", {}).get("authorizer", {}).get("principalId", ""),
254275
"REQUEST_METHOD": event.get("requestContext", {}).get("http", {}).get("method", ""),
255276
"SCRIPT_NAME": script_name,
256277
"SERVER_NAME": headers.get("Host", "lambda"),
257-
"SERVER_PORT": headers.get("X-Forwarded-Port", "80"),
278+
"SERVER_PORT": headers.get("X-Forwarded-Port", "443"),
258279
"SERVER_PROTOCOL": "HTTP/1.1",
259280
"wsgi.errors": sys.stderr,
260281
"wsgi.input": io.BytesIO(body),
261282
"wsgi.multiprocess": False,
262283
"wsgi.multithread": False,
263284
"wsgi.run_once": False,
264-
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "http"),
285+
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "https"),
265286
"wsgi.version": (1, 0),
266287
"serverless.authorizer": event.get("requestContext", {}).get("authorizer"),
267288
"serverless.event": event,
@@ -282,7 +303,7 @@ def handle_lambda_integration(app, event, context):
282303

283304
script_name = get_script_name(headers, event)
284305

285-
path_info = event["requestPath"]
306+
path_info = strip_express_gateway_query_params(event["requestPath"])
286307

287308
for key, value in event.get("path", {}).items():
288309
path_info = path_info.replace("{%s}" % key, value)
@@ -293,23 +314,23 @@ def handle_lambda_integration(app, event, context):
293314
body = get_body_bytes(event, body)
294315

295316
environ = {
296-
"CONTENT_LENGTH": str(len(body)),
317+
"CONTENT_LENGTH": str(len(body or "")),
297318
"CONTENT_TYPE": headers.get("Content-Type", ""),
298-
"PATH_INFO": url_unquote(path_info),
299-
"QUERY_STRING": url_encode(event.get("query", {})),
319+
"PATH_INFO": unquote(path_info),
320+
"QUERY_STRING": urlencode(event.get("query", {}), doseq=True),
300321
"REMOTE_ADDR": event.get("identity", {}).get("sourceIp", ""),
301322
"REMOTE_USER": event.get("principalId", ""),
302323
"REQUEST_METHOD": event.get("method", ""),
303324
"SCRIPT_NAME": script_name,
304325
"SERVER_NAME": headers.get("Host", "lambda"),
305-
"SERVER_PORT": headers.get("X-Forwarded-Port", "80"),
326+
"SERVER_PORT": headers.get("X-Forwarded-Port", "443"),
306327
"SERVER_PROTOCOL": "HTTP/1.1",
307328
"wsgi.errors": sys.stderr,
308329
"wsgi.input": io.BytesIO(body),
309330
"wsgi.multiprocess": False,
310331
"wsgi.multithread": False,
311332
"wsgi.run_once": False,
312-
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "http"),
333+
"wsgi.url_scheme": headers.get("X-Forwarded-Proto", "https"),
313334
"wsgi.version": (1, 0),
314335
"serverless.authorizer": event.get("enhancedAuthContext"),
315336
"serverless.event": event,

cli/src/pcluster/api/encoder.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,16 @@
99
# Generated by OpenAPI Generator (python-flask)
1010

1111
import datetime
12+
import json
1213

1314
import six
14-
from connexion.apps.flask_app import FlaskJSONEncoder
15+
from flask.json.provider import DefaultJSONProvider
1516

1617
from pcluster.api.models.base_model_ import Model
1718
from pcluster.utils import to_iso_timestr
1819

1920

20-
class JSONEncoder(FlaskJSONEncoder):
21+
class JSONEncoder(json.JSONEncoder):
2122
"""Make the model objects JSON serializable."""
2223

2324
include_nulls = False
@@ -35,4 +36,25 @@ def default(self, obj): # pylint: disable=arguments-renamed
3536
return dikt
3637
elif isinstance(obj, datetime.date):
3738
return to_iso_timestr(obj)
38-
return FlaskJSONEncoder.default(self, obj)
39+
return json.JSONEncoder.default(self, obj)
40+
41+
42+
class FlaskJSONEncoder(DefaultJSONProvider):
43+
"""Make the model objects JSON serializable."""
44+
45+
include_nulls = False
46+
47+
def default(self, obj): # pylint: disable=arguments-renamed
48+
"""Override the base method to add support for model objects serialization."""
49+
if isinstance(obj, Model):
50+
dikt = {}
51+
for attr, _ in six.iteritems(obj.openapi_types):
52+
value = getattr(obj, attr)
53+
if value is None and not self.include_nulls:
54+
continue
55+
attr = obj.attribute_map[attr]
56+
dikt[attr] = value
57+
return dikt
58+
elif isinstance(obj, datetime.date):
59+
return to_iso_timestr(obj)
60+
return super().default(obj)

cli/src/pcluster/api/errors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and
77
# limitations under the License.
88

9-
from connexion import ProblemException
9+
from connexion.exceptions import ProblemException
1010
from werkzeug.exceptions import HTTPException
1111

1212
from pcluster.api.models import (

cli/src/pcluster/api/flask_app.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
import functools
99
import logging
1010

11-
import connexion
12-
from connexion import ProblemException
11+
from connexion.apps.flask_app import FlaskApp
1312
from connexion.decorators.validation import ParameterValidator
13+
from connexion.exceptions import ProblemException
1414
from flask import Response, jsonify, request
1515
from werkzeug.exceptions import HTTPException
1616

@@ -74,9 +74,10 @@ def __init__(self, swagger_ui: bool = False, validate_responses=False):
7474
assert_valid_node_js()
7575
options = {"swagger_ui": swagger_ui}
7676

77-
self.app = connexion.FlaskApp(__name__, specification_dir="openapi/", skip_error_handlers=True)
77+
self.app = FlaskApp(__name__, specification_dir="openapi/", skip_error_handlers=True)
7878
self.flask_app = self.app.app
79-
self.flask_app.json_encoder = encoder.JSONEncoder
79+
self.flask_app.json_provider_class = encoder.FlaskJSONEncoder
80+
self.flask_app.json = encoder.FlaskJSONEncoder(self.flask_app)
8081
self.app.add_api(
8182
"openapi.yaml",
8283
arguments={"title": "ParallelCluster"},

0 commit comments

Comments
 (0)