Skip to content

Commit 09915ab

Browse files
committed
compare all static ok hosts with port, add 127.0.0.1 and localhost to it
use strncmp rather than memcmp, one of the strings coul be smaller than the other
1 parent 6575598 commit 09915ab

File tree

1 file changed

+16
-17
lines changed

1 file changed

+16
-17
lines changed

supervisor/shared/web_workflow/web_workflow.c

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -350,41 +350,40 @@ static bool _endswith(const char *str, const char *suffix) {
350350
return strcmp(str + (strlen(str) - strlen(suffix)), suffix) == 0;
351351
}
352352

353-
const char *ok_hosts[] = {"code.circuitpython.org"};
353+
const char *ok_hosts[] = {
354+
"code.circuitpython.org",
355+
"127.0.0.1",
356+
"localhost",
357+
};
354358

355359
static bool _origin_ok(const char *origin) {
356360
const char *http = "http://";
357361
const char *local = ".local";
358362

359363
// note: redirected requests send an Origin of "null" and will be caught by this
360-
if (memcmp(origin, http, strlen(http)) != 0) {
364+
if (strncmp(origin, http, strlen(http)) != 0) {
361365
return false;
362366
}
363367
// These are prefix checks up to : so that any port works.
364368
const char *hostname = common_hal_mdns_server_get_hostname(&mdns);
365369
const char *end = origin + strlen(http) + strlen(hostname) + strlen(local);
366-
if (memcmp(origin + strlen(http), hostname, strlen(hostname)) == 0 &&
367-
memcmp(origin + strlen(http) + strlen(hostname), local, strlen(local)) == 0 &&
370+
if (strncmp(origin + strlen(http), hostname, strlen(hostname)) == 0 &&
371+
strncmp(origin + strlen(http) + strlen(hostname), local, strlen(local)) == 0 &&
368372
(end[0] == '\0' || end[0] == ':')) {
369373
return true;
370374
}
371375

372376
end = origin + strlen(http) + strlen(_our_ip_encoded);
373-
if (memcmp(origin + strlen(http), _our_ip_encoded, strlen(_our_ip_encoded)) == 0 &&
377+
if (strncmp(origin + strlen(http), _our_ip_encoded, strlen(_our_ip_encoded)) == 0 &&
374378
(end[0] == '\0' || end[0] == ':')) {
375379
return true;
376380
}
377381

378-
const char *localhost = "127.0.0.1";
379-
end = origin + strlen(http) + strlen(localhost);
380-
if (memcmp(origin + strlen(http), localhost, strlen(localhost)) == 0
381-
&& (end[0] == '\0' || end[0] == ':')) {
382-
return true;
383-
}
384-
385382
for (size_t i = 0; i < MP_ARRAY_SIZE(ok_hosts); i++) {
386-
// This checks exactly.
387-
if (strcmp(origin + strlen(http), ok_hosts[i]) == 0) {
383+
// Allows any port
384+
end = origin + strlen(http) + strlen(ok_hosts[i]);
385+
if (strncmp(origin + strlen(http), ok_hosts[i], strlen(ok_hosts[i])) == 0
386+
&& (end[0] == '\0' || end[0] == ':')) {
388387
return true;
389388
}
390389
}
@@ -911,7 +910,7 @@ static bool _reply(socketpool_socket_obj_t *socket, _request *request) {
911910
} else if (strlen(request->origin) > 0 && !_origin_ok(request->origin)) {
912911
ESP_LOGE(TAG, "bad origin %s", request->origin);
913912
_reply_forbidden(socket, request);
914-
} else if (memcmp(request->path, "/fs/", 4) == 0) {
913+
} else if (strncmp(request->path, "/fs/", 4) == 0) {
915914
if (strcasecmp(request->method, "OPTIONS") == 0) {
916915
// OPTIONS is sent for CORS preflight, unauthenticated
917916
_reply_access_control(socket, request);
@@ -1032,7 +1031,7 @@ static bool _reply(socketpool_socket_obj_t *socket, _request *request) {
10321031
}
10331032
}
10341033
}
1035-
} else if (memcmp(request->path, "/cp/", 4) == 0) {
1034+
} else if (strncmp(request->path, "/cp/", 4) == 0) {
10361035
const char *path = request->path + 3;
10371036
if (strcasecmp(request->method, "OPTIONS") == 0) {
10381037
// handle preflight requests to /cp/
@@ -1177,7 +1176,7 @@ static void _process_request(socketpool_socket_obj_t *socket, _request *request)
11771176
request->state = STATE_HEADER_KEY;
11781177
if (strcasecmp(request->header_key, "Authorization") == 0) {
11791178
const char *prefix = "Basic ";
1180-
request->authenticated = memcmp(request->header_value, prefix, strlen(prefix)) == 0 &&
1179+
request->authenticated = strncmp(request->header_value, prefix, strlen(prefix)) == 0 &&
11811180
strcmp(_api_password, request->header_value + strlen(prefix)) == 0;
11821181
} else if (strcasecmp(request->header_key, "Host") == 0) {
11831182
request->redirect = strcmp(request->header_value, "circuitpython.local") == 0;

0 commit comments

Comments
 (0)