33#include  " chat.h" 
44#include  " chatmodel.h" 
55#include  " modellist.h" 
6+ #include  " mwhttpserver.h" 
67#include  " mysettings.h" 
78#include  " utils.h"   //  IWYU pragma: keep
89
910#include  < fmt/format.h> 
1011#include  < gpt4all-backend/llmodel.h> 
1112
13+ #include  < QAbstractSocket> 
1214#include  < QByteArray> 
1315#include  < QCborArray> 
1416#include  < QCborMap> 
5153
5254using  namespace  std ::string_literals; 
5355using  namespace  Qt ::Literals::StringLiterals; 
56+ using  namespace  gpt4all ::ui; 
5457
5558// #define DEBUG
5659
@@ -443,6 +446,8 @@ Server::Server(Chat *chat)
443446    connect (chat, &Chat::collectionListChanged, this , &Server::handleCollectionListChanged, Qt::QueuedConnection);
444447}
445448
449+ Server::~Server () = default ;
450+ 
446451static  QJsonObject requestFromJson (const  QByteArray &request)
447452{
448453    QJsonParseError err;
@@ -455,17 +460,57 @@ static QJsonObject requestFromJson(const QByteArray &request)
455460    return  document.object ();
456461}
457462
463+ // / @brief Check if a host is safe to use to connect to the server.
464+ // /
465+ // / GPT4All's local server is not safe to expose to the internet, as it does not provide
466+ // / any form of authentication. DNS rebind attacks bypass CORS and without additional host
467+ // / header validation, malicious websites can access the server in client-side js.
468+ // /
469+ // / @param host The value of the "Host" header or ":authority" pseudo-header
470+ // / @return true if the host is unsafe, false otherwise
471+ static  bool  isHostUnsafe (const  QString &host)
472+ {
473+     QHostAddress addr;
474+     if  (addr.setAddress (host) && addr.protocol () == QAbstractSocket::IPv4Protocol)
475+         return  false ; //  ipv4
476+ 
477+     //  ipv6 host is wrapped in square brackets
478+     static  const  QRegularExpression ipv6Re (uR"( ^\[(.+)\]$)"  _s);
479+     if  (auto  match = ipv6Re.match (host); match.hasMatch ()) {
480+         auto  ipv6 = match.captured (1 );
481+         if  (addr.setAddress (ipv6) && addr.protocol () == QAbstractSocket::IPv6Protocol)
482+             return  false ; //  ipv6
483+     }
484+ 
485+     if  (!host.contains (' .'  ))
486+         return  false ; //  dotless hostname
487+ 
488+     static  const  QStringList allowedTlds { u" .local"  _s, u" .test"  _s, u" .internal"  _s };
489+     for  (auto  &tld : allowedTlds)
490+         if  (host.endsWith (tld, Qt::CaseInsensitive))
491+             return  false ; //  local TLD
492+ 
493+     return  true ; //  unsafe
494+ }
495+ 
458496void  Server::start ()
459497{
460-     m_server = std::make_unique<QHttpServer>(this );
461-     auto  *tcpServer = new  QTcpServer (m_server.get ());
498+     m_server = std::make_unique<MwHttpServer>();
499+ 
500+     m_server->addBeforeRequestHandler ([](const  QHttpServerRequest &req) -> std::optional<QHttpServerResponse> {
501+         //  this works for HTTP/1.1 "Host" header and HTTP/2 ":authority" pseudo-header
502+         auto  host = req.url ().host ();
503+         if  (!host.isEmpty () && isHostUnsafe (host))
504+             return  QHttpServerResponse (QHttpServerResponder::StatusCode::Forbidden);
505+         return  std::nullopt ;
506+     });
462507
463508    auto  port = MySettings::globalInstance ()->networkPort ();
464-     if  (!tcpServer->listen (QHostAddress::LocalHost, port)) {
509+     if  (!m_server-> tcpServer () ->listen (QHostAddress::LocalHost, port)) {
465510        qWarning () << " Server ERROR: Failed to listen on port"   << port;
466511        return ;
467512    }
468-     if  (!m_server->bind (tcpServer )) {
513+     if  (!m_server->bind ()) {
469514        qWarning () << " Server ERROR: Failed to HTTP server to socket"   << port;
470515        return ;
471516    }
@@ -490,7 +535,7 @@ void Server::start()
490535        }
491536    );
492537
493-     m_server->route (" /v1/models/<arg>"  , QHttpServerRequest::Method::Get,
538+     m_server->route < const  QString &> (" /v1/models/<arg>"  , QHttpServerRequest::Method::Get,
494539        [](const  QString &model, const  QHttpServerRequest &) {
495540            if  (!MySettings::globalInstance ()->serverChat ())
496541                return  QHttpServerResponse (QHttpServerResponder::StatusCode::Unauthorized);
@@ -562,7 +607,7 @@ void Server::start()
562607
563608    //  Respond with code 405 to wrong HTTP methods:
564609    m_server->route (" /v1/models"  ,  QHttpServerRequest::Method::Post,
565-         [] {
610+         []( const  QHttpServerRequest &)  {
566611            if  (!MySettings::globalInstance ()->serverChat ())
567612                return  QHttpServerResponse (QHttpServerResponder::StatusCode::Unauthorized);
568613            return  QHttpServerResponse (
@@ -573,8 +618,8 @@ void Server::start()
573618        }
574619    );
575620
576-     m_server->route (" /v1/models/<arg>"  , QHttpServerRequest::Method::Post,
577-         [](const  QString &model) {
621+     m_server->route < const  QString &> (" /v1/models/<arg>"  , QHttpServerRequest::Method::Post,
622+         [](const  QString &model,  const  QHttpServerRequest & ) {
578623            (void )model;
579624            if  (!MySettings::globalInstance ()->serverChat ())
580625                return  QHttpServerResponse (QHttpServerResponder::StatusCode::Unauthorized);
@@ -587,7 +632,7 @@ void Server::start()
587632    );
588633
589634    m_server->route (" /v1/completions"  , QHttpServerRequest::Method::Get,
590-         [] {
635+         []( const  QHttpServerRequest &)  {
591636            if  (!MySettings::globalInstance ()->serverChat ())
592637                return  QHttpServerResponse (QHttpServerResponder::StatusCode::Unauthorized);
593638            return  QHttpServerResponse (
@@ -598,7 +643,7 @@ void Server::start()
598643    );
599644
600645    m_server->route (" /v1/chat/completions"  , QHttpServerRequest::Method::Get,
601-         [] {
646+         []( const  QHttpServerRequest &)  {
602647            if  (!MySettings::globalInstance ()->serverChat ())
603648                return  QHttpServerResponse (QHttpServerResponder::StatusCode::Unauthorized);
604649            return  QHttpServerResponse (
0 commit comments