diff --git a/mlir/include/mlir/Tools/lsp-server-support/Transport.h b/mlir/include/mlir/Tools/lsp-server-support/Transport.h index 6843bc76ab9dc..0010a475fedd2 100644 --- a/mlir/include/mlir/Tools/lsp-server-support/Transport.h +++ b/mlir/include/mlir/Tools/lsp-server-support/Transport.h @@ -43,14 +43,59 @@ enum JSONStreamStyle { Delimited }; +/// An abstract class used by the JSONTransport to read JSON message. +class JSONTransportInput { +public: + explicit JSONTransportInput(JSONStreamStyle style = JSONStreamStyle::Standard) + : style(style) {} + virtual ~JSONTransportInput() = default; + + virtual bool hasError() const = 0; + virtual bool isEndOfInput() const = 0; + + /// Read in a message from the input stream. + LogicalResult readMessage(std::string &json) { + return style == JSONStreamStyle::Delimited ? readDelimitedMessage(json) + : readStandardMessage(json); + } + virtual LogicalResult readDelimitedMessage(std::string &json) = 0; + virtual LogicalResult readStandardMessage(std::string &json) = 0; + +private: + /// The JSON stream style to use. + JSONStreamStyle style; +}; + +/// Concrete implementation of the JSONTransportInput that reads from a file. +class JSONTransportInputOverFile : public JSONTransportInput { +public: + explicit JSONTransportInputOverFile( + std::FILE *in, JSONStreamStyle style = JSONStreamStyle::Standard) + : JSONTransportInput(style), in(in) {} + + bool hasError() const final { return ferror(in); } + bool isEndOfInput() const final { return feof(in); } + + LogicalResult readDelimitedMessage(std::string &json) final; + LogicalResult readStandardMessage(std::string &json) final; + +private: + std::FILE *in; +}; + /// A transport class that performs the JSON-RPC communication with the LSP /// client. class JSONTransport { public: + JSONTransport(std::unique_ptr in, raw_ostream &out, + bool prettyOutput = false) + : in(std::move(in)), out(out), prettyOutput(prettyOutput) {} + JSONTransport(std::FILE *in, raw_ostream &out, JSONStreamStyle style = JSONStreamStyle::Standard, bool prettyOutput = false) - : in(in), out(out), style(style), prettyOutput(prettyOutput) {} + : in(std::make_unique(in, style)), out(out), + prettyOutput(prettyOutput) {} /// The following methods are used to send a message to the LSP client. void notify(StringRef method, llvm::json::Value params); @@ -66,22 +111,12 @@ class JSONTransport { /// Writes the given message to the output stream. void sendMessage(llvm::json::Value msg); - /// Read in a message from the input stream. - LogicalResult readMessage(std::string &json) { - return style == JSONStreamStyle::Delimited ? readDelimitedMessage(json) - : readStandardMessage(json); - } - LogicalResult readDelimitedMessage(std::string &json); - LogicalResult readStandardMessage(std::string &json); - - /// An output buffer used when building output messages. +private: + /// The input to read a message from. + std::unique_ptr in; SmallVector outputBuffer; - /// The input file stream. - std::FILE *in; /// The output file stream. raw_ostream &out; - /// The JSON stream style to use. - JSONStreamStyle style; /// If the output JSON should be formatted for easier readability. bool prettyOutput; }; diff --git a/mlir/lib/Tools/lsp-server-support/Transport.cpp b/mlir/lib/Tools/lsp-server-support/Transport.cpp index ad8308f69aead..d0863ba0ae087 100644 --- a/mlir/lib/Tools/lsp-server-support/Transport.cpp +++ b/mlir/lib/Tools/lsp-server-support/Transport.cpp @@ -209,13 +209,13 @@ void JSONTransport::reply(llvm::json::Value id, llvm::Error JSONTransport::run(MessageHandler &handler) { std::string json; - while (!feof(in)) { - if (ferror(in)) { + while (!in->isEndOfInput()) { + if (in->hasError()) { return llvm::errorCodeToError( std::error_code(errno, std::system_category())); } - if (succeeded(readMessage(json))) { + if (succeeded(in->readMessage(json))) { if (llvm::Expected doc = llvm::json::parse(json)) { if (!handleMessage(std::move(*doc), handler)) return llvm::Error::success(); @@ -303,13 +303,14 @@ LogicalResult readLine(std::FILE *in, SmallVectorImpl &out) { // Returns std::nullopt when: // - ferror(), feof(), or shutdownRequested() are set. // - Content-Length is missing or empty (protocol error) -LogicalResult JSONTransport::readStandardMessage(std::string &json) { +LogicalResult +JSONTransportInputOverFile::readStandardMessage(std::string &json) { // A Language Server Protocol message starts with a set of HTTP headers, // delimited by \r\n, and terminated by an empty line (\r\n). unsigned long long contentLength = 0; llvm::SmallString<128> line; while (true) { - if (feof(in) || ferror(in) || failed(readLine(in, line))) + if (feof(in) || hasError() || failed(readLine(in, line))) return failure(); // Content-Length is a mandatory header, and the only one we handle. @@ -349,7 +350,8 @@ LogicalResult JSONTransport::readStandardMessage(std::string &json) { /// This is a testing path, so favor simplicity over performance here. /// When returning failure: feof(), ferror(), or shutdownRequested() will be /// set. -LogicalResult JSONTransport::readDelimitedMessage(std::string &json) { +LogicalResult +JSONTransportInputOverFile::readDelimitedMessage(std::string &json) { json.clear(); llvm::SmallString<128> line; while (succeeded(readLine(in, line))) {