Skip to content

Commit 6d89f8f

Browse files
committed
Add binary safety check to server
1 parent 36696f3 commit 6d89f8f

File tree

4 files changed

+77
-6
lines changed

4 files changed

+77
-6
lines changed

llamafile/server/client.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "llamafile/server/server.h"
2525
#include "llamafile/server/time.h"
2626
#include "llamafile/server/tokenbucket.h"
27+
#include "llamafile/server/utils.h"
2728
#include "llamafile/server/worker.h"
2829
#include "llamafile/string.h"
2930
#include "llamafile/threadlocal.h"
@@ -478,7 +479,7 @@ Client::send_response_chunk(const std::string_view content)
478479

479480
// perform send system call
480481
ssize_t sent;
481-
if ((sent = writev(fd_, iov, 3)) != bytes) {
482+
if ((sent = safe_writev(fd_, iov, 3)) != bytes) {
482483
if (sent == -1 && errno != EAGAIN && errno != ECONNRESET)
483484
SLOG("writev failed %m");
484485
close_connection_ = true;
@@ -504,15 +505,34 @@ Client::send_response_finish()
504505
return send("0\r\n\r\n");
505506
}
506507

507-
// writes raw data to socket
508+
// writes any old data to socket
509+
//
510+
// unlike send() this won't fail if binary content is detected.
511+
bool
512+
Client::send_binary(const void* p, size_t n)
513+
{
514+
ssize_t sent;
515+
if ((sent = write(fd_, p, n)) != n) {
516+
if (sent == -1 && errno != EAGAIN && errno != ECONNRESET)
517+
SLOG("write failed %m");
518+
close_connection_ = true;
519+
return false;
520+
}
521+
return true;
522+
}
523+
524+
// writes non-binary data to socket
508525
//
509526
// consider using the higher level methods like send_error(),
510527
// send_response(), send_response_start(), etc.
511528
bool
512529
Client::send(const std::string_view s)
513530
{
531+
iovec iov[1];
514532
ssize_t sent;
515-
if ((sent = write(fd_, s.data(), s.size())) != s.size()) {
533+
iov[0].iov_base = (void*)s.data();
534+
iov[0].iov_len = s.size();
535+
if ((sent = safe_writev(fd_, iov, 1)) != s.size()) {
516536
if (sent == -1 && errno != EAGAIN && errno != ECONNRESET)
517537
SLOG("write failed %m");
518538
close_connection_ = true;
@@ -521,7 +541,7 @@ Client::send(const std::string_view s)
521541
return true;
522542
}
523543

524-
// writes two pieces of raw data to socket in single system call
544+
// writes two pieces of non-binary data to socket in single system call
525545
//
526546
// consider using the higher level methods like send_error(),
527547
// send_response(), send_response_start(), etc.
@@ -534,7 +554,7 @@ Client::send2(const std::string_view s1, const std::string_view s2)
534554
iov[0].iov_len = s1.size();
535555
iov[1].iov_base = (void*)s2.data();
536556
iov[1].iov_len = s2.size();
537-
if ((sent = writev(fd_, iov, 2)) != s1.size() + s2.size()) {
557+
if ((sent = safe_writev(fd_, iov, 2)) != s1.size() + s2.size()) {
538558
if (sent == -1 && errno != EAGAIN && errno != ECONNRESET)
539559
SLOG("writev failed %m");
540560
close_connection_ = true;
@@ -755,7 +775,7 @@ Client::dispatcher()
755775
close_connection_ = true;
756776
return false;
757777
}
758-
if (!send(std::string_view(buf, chunk))) {
778+
if (!send_binary(buf, chunk)) {
759779
close_connection_ = true;
760780
return false;
761781
}

llamafile/server/client.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ struct Client
8585
bool read_content() __wur;
8686
bool send_continue() __wur;
8787
bool send(const std::string_view) __wur;
88+
bool send_binary(const void*, size_t) __wur;
8889
void defer_cleanup(void (*)(void*), void*);
8990
bool send_error(int, const char* = nullptr);
9091
char* append_http_response_message(char*, int, const char* = nullptr);

llamafile/server/utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <__fwd/string_view.h>
2121
#include <__fwd/vector.h>
2222
#include <optional>
23+
#include <sys/uio.h>
2324

2425
struct llama_model;
2526

@@ -28,6 +29,9 @@ namespace server {
2829

2930
class Atom;
3031

32+
ssize_t
33+
safe_writev(int, const iovec*, int);
34+
3135
bool
3236
atob(std::string_view, bool);
3337

llamafile/server/writev.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
2+
// vi: set et ft=cpp ts=4 sts=4 sw=4 fenc=utf-8 :vi
3+
//
4+
// Copyright 2024 Mozilla Foundation
5+
//
6+
// Licensed under the Apache License, Version 2.0 (the "License");
7+
// you may not use this file except in compliance with the License.
8+
// You may obtain a copy of the License at
9+
//
10+
// http://www.apache.org/licenses/LICENSE-2.0
11+
//
12+
// Unless required by applicable law or agreed to in writing, software
13+
// distributed under the License is distributed on an "AS IS" BASIS,
14+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
// See the License for the specific language governing permissions and
16+
// limitations under the License.
17+
18+
#include "llamafile/server/log.h"
19+
#include "utils.h"
20+
#include <cerrno>
21+
#include <string_view>
22+
23+
namespace lf {
24+
namespace server {
25+
26+
ssize_t
27+
safe_writev(int fd, const iovec* iov, int iovcnt)
28+
{
29+
for (int i = 0; i < iovcnt; ++i) {
30+
bool has_binary = false;
31+
size_t n = iov[i].iov_len;
32+
unsigned char* p = (unsigned char*)iov[i].iov_base;
33+
for (size_t j = 0; j < n; ++j) {
34+
has_binary |= p[j] < 7;
35+
}
36+
if (has_binary) {
37+
SLOG("safe_writev() detected binary server is compromised");
38+
errno = EINVAL;
39+
return -1;
40+
}
41+
}
42+
return writev(fd, iov, iovcnt);
43+
}
44+
45+
} // namespace server
46+
} // namespace lf

0 commit comments

Comments
 (0)