diff --git a/mlir/include/mlir/Tools/lsp-server-support/Logging.h b/llvm/include/llvm/Support/LSP/Logging.h similarity index 55% rename from mlir/include/mlir/Tools/lsp-server-support/Logging.h rename to llvm/include/llvm/Support/LSP/Logging.h index 9b090d05f7fa4..fe65899b1d4ce 100644 --- a/mlir/include/mlir/Tools/lsp-server-support/Logging.h +++ b/llvm/include/llvm/Support/LSP/Logging.h @@ -1,4 +1,4 @@ -//===- Logging.h - MLIR LSP Server Logging ----------------------*- C++ -*-===// +//===- Logging.h - LSP Server Logging ----------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,16 +6,15 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_TOOLS_LSPSERVERSUPPORT_LOGGING_H -#define MLIR_TOOLS_LSPSERVERSUPPORT_LOGGING_H +#ifndef LLVM_SUPPORT_LSP_LOGGING_H +#define LLVM_SUPPORT_LSP_LOGGING_H -#include "mlir/Support/LLVM.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" #include #include -namespace mlir { +namespace llvm { namespace lsp { /// This class represents the main interface for logging, and allows for @@ -26,21 +25,18 @@ class Logger { enum class Level { Debug, Info, Error }; /// Set the severity level of the logger. - static void setLogLevel(Level logLevel); + static void setLogLevel(Level LogLevel); /// Initiate a log message at various severity levels. These should be called /// after a call to `initialize`. - template - static void debug(const char *fmt, Ts &&...vals) { - log(Level::Debug, fmt, llvm::formatv(fmt, std::forward(vals)...)); + template static void debug(const char *Fmt, Ts &&...Vals) { + log(Level::Debug, Fmt, llvm::formatv(Fmt, std::forward(Vals)...)); } - template - static void info(const char *fmt, Ts &&...vals) { - log(Level::Info, fmt, llvm::formatv(fmt, std::forward(vals)...)); + template static void info(const char *Fmt, Ts &&...Vals) { + log(Level::Info, Fmt, llvm::formatv(Fmt, std::forward(Vals)...)); } - template - static void error(const char *fmt, Ts &&...vals) { - log(Level::Error, fmt, llvm::formatv(fmt, std::forward(vals)...)); + template static void error(const char *Fmt, Ts &&...Vals) { + log(Level::Error, Fmt, llvm::formatv(Fmt, std::forward(Vals)...)); } private: @@ -50,16 +46,16 @@ class Logger { static Logger &get(); /// Start a log message with the given severity level. - static void log(Level logLevel, const char *fmt, - const llvm::formatv_object_base &message); + static void log(Level LogLevel, const char *Fmt, + const llvm::formatv_object_base &Message); /// The minimum logging level. Messages with lower level are ignored. - Level logLevel = Level::Error; + Level LogLevel = Level::Error; /// A mutex used to guard logging. - std::mutex mutex; + std::mutex Mutex; }; } // namespace lsp -} // namespace mlir +} // namespace llvm -#endif // MLIR_TOOLS_LSPSERVERSUPPORT_LOGGING_H +#endif // LLVM_SUPPORT_LSP_LOGGING_H diff --git a/mlir/include/mlir/Tools/lsp-server-support/Protocol.h b/llvm/include/llvm/Support/LSP/Protocol.h similarity index 98% rename from mlir/include/mlir/Tools/lsp-server-support/Protocol.h rename to llvm/include/llvm/Support/LSP/Protocol.h index cc06dbfedb42a..93b82f1e581f8 100644 --- a/mlir/include/mlir/Tools/lsp-server-support/Protocol.h +++ b/llvm/include/llvm/Support/LSP/Protocol.h @@ -20,20 +20,24 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_TOOLS_LSPSERVERSUPPORT_PROTOCOL_H -#define MLIR_TOOLS_LSPSERVERSUPPORT_PROTOCOL_H +#ifndef LLVM_SUPPORT_LSP_PROTOCOL_H +#define LLVM_SUPPORT_LSP_PROTOCOL_H -#include "mlir/Support/LLVM.h" #include "llvm/Support/JSON.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" #include #include #include #include -#include -namespace mlir { +// This file is using the LSP syntax for identifier names which is different +// from the LLVM coding standard. To avoid the clang-tidy warnings, we're +// disabling one check here. +// NOLINTBEGIN(readability-identifier-naming) + +namespace llvm { namespace lsp { enum class ErrorCode { @@ -1241,12 +1245,11 @@ struct CodeAction { llvm::json::Value toJSON(const CodeAction &); } // namespace lsp -} // namespace mlir +} // namespace llvm namespace llvm { -template <> -struct format_provider { - static void format(const mlir::lsp::Position &pos, raw_ostream &os, +template <> struct format_provider { + static void format(const llvm::lsp::Position &pos, raw_ostream &os, StringRef style) { assert(style.empty() && "style modifiers for this type are not supported"); os << pos; @@ -1255,3 +1258,5 @@ struct format_provider { } // namespace llvm #endif + +// NOLINTEND(readability-identifier-naming) diff --git a/llvm/include/llvm/Support/LSP/Transport.h b/llvm/include/llvm/Support/LSP/Transport.h new file mode 100644 index 0000000000000..ccd7f213aa277 --- /dev/null +++ b/llvm/include/llvm/Support/LSP/Transport.h @@ -0,0 +1,289 @@ +//===--- Transport.h - Sending and Receiving LSP messages -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// The language server protocol is usually implemented by writing messages as +// JSON-RPC over the stdin/stdout of a subprocess. This file contains a JSON +// transport interface that handles this communication. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_SUPPORT_LSP_TRANSPORT_H +#define LLVM_SUPPORT_LSP_TRANSPORT_H + +#include "llvm/ADT/FunctionExtras.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/FormatAdapters.h" +#include "llvm/Support/JSON.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Protocol.h" +#include "llvm/Support/raw_ostream.h" +#include + +namespace llvm { +// Simple helper function that returns a string as printed from a op. +template static std::string debugString(T &&Op) { + std::string InstrStr; + llvm::raw_string_ostream Os(InstrStr); + Os << Op; + return Os.str(); +} +namespace lsp { +class MessageHandler; + +//===----------------------------------------------------------------------===// +// JSONTransport +//===----------------------------------------------------------------------===// + +/// The encoding style of the JSON-RPC messages (both input and output). +enum JSONStreamStyle { + /// Encoding per the LSP specification, with mandatory Content-Length header. + Standard, + /// Messages are delimited by a '// -----' line. Comment lines start with //. + Delimited +}; + +/// An abstract class used by the JSONTransport to read JSON message. +class JSONTransportInput { +public: + explicit JSONTransportInput(JSONStreamStyle Style = JSONStreamStyle::Standard) + : Style(Style) {} + virtual ~JSONTransportInput() = default; + + virtual bool hasError() const = 0; + virtual bool isEndOfInput() const = 0; + + /// Read in a message from the input stream. + LogicalResult readMessage(std::string &Json) { + return Style == JSONStreamStyle::Delimited ? readDelimitedMessage(Json) + : readStandardMessage(Json); + } + virtual LogicalResult readDelimitedMessage(std::string &Json) = 0; + virtual LogicalResult readStandardMessage(std::string &Json) = 0; + +private: + /// The JSON stream style to use. + JSONStreamStyle Style; +}; + +/// Concrete implementation of the JSONTransportInput that reads from a file. +class JSONTransportInputOverFile : public JSONTransportInput { +public: + explicit JSONTransportInputOverFile( + std::FILE *In, JSONStreamStyle Style = JSONStreamStyle::Standard) + : JSONTransportInput(Style), In(In) {} + + bool hasError() const final { return ferror(In); } + bool isEndOfInput() const final { return feof(In); } + + LogicalResult readDelimitedMessage(std::string &Json) final; + LogicalResult readStandardMessage(std::string &Json) final; + +private: + std::FILE *In; +}; + +/// A transport class that performs the JSON-RPC communication with the LSP +/// client. +class JSONTransport { +public: + JSONTransport(std::unique_ptr In, raw_ostream &Out, + bool PrettyOutput = false) + : In(std::move(In)), Out(Out), PrettyOutput(PrettyOutput) {} + + JSONTransport(std::FILE *In, raw_ostream &Out, + JSONStreamStyle Style = JSONStreamStyle::Standard, + bool PrettyOutput = false) + : In(std::make_unique(In, Style)), Out(Out), + PrettyOutput(PrettyOutput) {} + + /// The following methods are used to send a message to the LSP client. + void notify(StringRef Method, llvm::json::Value Params); + void call(StringRef Method, llvm::json::Value Params, llvm::json::Value Id); + void reply(llvm::json::Value Id, llvm::Expected Result); + + /// Start executing the JSON-RPC transport. + llvm::Error run(MessageHandler &Handler); + +private: + /// Dispatches the given incoming json message to the message handler. + bool handleMessage(llvm::json::Value Msg, MessageHandler &Handler); + /// Writes the given message to the output stream. + void sendMessage(llvm::json::Value Msg); + +private: + /// The input to read a message from. + std::unique_ptr In; + SmallVector OutputBuffer; + /// The output file stream. + raw_ostream &Out; + /// If the output JSON should be formatted for easier readability. + bool PrettyOutput; +}; + +//===----------------------------------------------------------------------===// +// MessageHandler +//===----------------------------------------------------------------------===// + +/// A Callback is a void function that accepts Expected. This is +/// accepted by functions that logically return T. +template +using Callback = llvm::unique_function)>; + +/// An OutgoingNotification is a function used for outgoing notifications +/// send to the client. +template +using OutgoingNotification = llvm::unique_function; + +/// An OutgoingRequest is a function used for outgoing requests to send to +/// the client. +template +using OutgoingRequest = + llvm::unique_function; + +/// An `OutgoingRequestCallback` is invoked when an outgoing request to the +/// client receives a response in turn. It is passed the original request's ID, +/// as well as the response result. +template +using OutgoingRequestCallback = + std::function)>; + +/// A handler used to process the incoming transport messages. +class MessageHandler { +public: + MessageHandler(JSONTransport &Transport) : Transport(Transport) {} + + bool onNotify(StringRef Method, llvm::json::Value Value); + bool onCall(StringRef Method, llvm::json::Value Params, llvm::json::Value Id); + bool onReply(llvm::json::Value Id, llvm::Expected Result); + + template + static llvm::Expected parse(const llvm::json::Value &Raw, + StringRef PayloadName, StringRef PayloadKind) { + T Result; + llvm::json::Path::Root Root; + if (fromJSON(Raw, Result, Root)) + return std::move(Result); + + // Dump the relevant parts of the broken message. + std::string Context; + llvm::raw_string_ostream Os(Context); + Root.printErrorContext(Raw, Os); + + // Report the error (e.g. to the client). + return llvm::make_error( + llvm::formatv("failed to decode {0} {1}: {2}", PayloadName, PayloadKind, + fmt_consume(Root.getError())), + ErrorCode::InvalidParams); + } + + template + void method(llvm::StringLiteral Method, ThisT *ThisPtr, + void (ThisT::*Handler)(const Param &, Callback)) { + MethodHandlers[Method] = [Method, Handler, + ThisPtr](llvm::json::Value RawParams, + Callback Reply) { + llvm::Expected Parameter = + parse(RawParams, Method, "request"); + if (!Parameter) + return Reply(Parameter.takeError()); + (ThisPtr->*Handler)(*Parameter, std::move(Reply)); + }; + } + + template + void notification(llvm::StringLiteral Method, ThisT *ThisPtr, + void (ThisT::*Handler)(const Param &)) { + NotificationHandlers[Method] = [Method, Handler, + ThisPtr](llvm::json::Value RawParams) { + llvm::Expected Parameter = + parse(RawParams, Method, "notification"); + if (!Parameter) { + return llvm::consumeError(llvm::handleErrors( + Parameter.takeError(), [](const LSPError &LspError) { + Logger::error("JSON parsing error: {0}", + LspError.message.c_str()); + })); + } + (ThisPtr->*Handler)(*Parameter); + }; + } + + /// Create an OutgoingNotification object used for the given method. + template + OutgoingNotification outgoingNotification(llvm::StringLiteral Method) { + return [&, Method](const T &Params) { + std::lock_guard TransportLock(TransportOutputMutex); + Logger::info("--> {0}", Method); + Transport.notify(Method, llvm::json::Value(Params)); + }; + } + + /// Create an OutgoingRequest function that, when called, sends a request with + /// the given method via the transport. Should the outgoing request be + /// met with a response, the result JSON is parsed and the response callback + /// is invoked. + template + OutgoingRequest + outgoingRequest(llvm::StringLiteral Method, + OutgoingRequestCallback Callback) { + return [&, Method, Callback](const Param &Parameter, llvm::json::Value Id) { + auto CallbackWrapper = [Method, Callback = std::move(Callback)]( + llvm::json::Value Id, + llvm::Expected Value) { + if (!Value) + return Callback(std::move(Id), Value.takeError()); + + std::string ResponseName = llvm::formatv("reply:{0}({1})", Method, Id); + llvm::Expected ParseResult = + parse(*Value, ResponseName, "response"); + if (!ParseResult) + return Callback(std::move(Id), ParseResult.takeError()); + + return Callback(std::move(Id), *ParseResult); + }; + + { + std::lock_guard Lock(ResponseHandlersMutex); + ResponseHandlers.insert( + {debugString(Id), std::make_pair(Method.str(), CallbackWrapper)}); + } + + std::lock_guard TransportLock(TransportOutputMutex); + Logger::info("--> {0}({1})", Method, Id); + Transport.call(Method, llvm::json::Value(Parameter), Id); + }; + } + +private: + template + using HandlerMap = llvm::StringMap>; + + HandlerMap NotificationHandlers; + HandlerMap)> + MethodHandlers; + + /// A pair of (1) the original request's method name, and (2) the callback + /// function to be invoked for responses. + using ResponseHandlerTy = + std::pair>; + /// A mapping from request/response ID to response handler. + llvm::StringMap ResponseHandlers; + /// Mutex to guard insertion into the response handler map. + std::mutex ResponseHandlersMutex; + + JSONTransport &Transport; + + /// Mutex to guard sending output messages to the transport. + std::mutex TransportOutputMutex; +}; + +} // namespace lsp +} // namespace llvm + +#endif diff --git a/llvm/lib/Support/CMakeLists.txt b/llvm/lib/Support/CMakeLists.txt index 2528e8bd1142a..7da972f372c5b 100644 --- a/llvm/lib/Support/CMakeLists.txt +++ b/llvm/lib/Support/CMakeLists.txt @@ -135,6 +135,7 @@ if (UNIX AND "${CMAKE_SYSTEM_NAME}" MATCHES "AIX") endif() add_subdirectory(BLAKE3) +add_subdirectory(LSP) add_llvm_component_library(LLVMSupport ABIBreak.cpp diff --git a/llvm/lib/Support/LSP/CMakeLists.txt b/llvm/lib/Support/LSP/CMakeLists.txt new file mode 100644 index 0000000000000..6094d9ac315c0 --- /dev/null +++ b/llvm/lib/Support/LSP/CMakeLists.txt @@ -0,0 +1,8 @@ +add_llvm_component_library(LLVMSupportLSP + Protocol.cpp + Transport.cpp + Logging.cpp + + DEPENDS + LLVMSupport +) diff --git a/mlir/lib/Tools/lsp-server-support/Logging.cpp b/llvm/lib/Support/LSP/Logging.cpp similarity index 55% rename from mlir/lib/Tools/lsp-server-support/Logging.cpp rename to llvm/lib/Support/LSP/Logging.cpp index 373e2165c244d..b36621ae1c6c6 100644 --- a/mlir/lib/Tools/lsp-server-support/Logging.cpp +++ b/llvm/lib/Support/LSP/Logging.cpp @@ -6,36 +6,36 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Tools/lsp-server-support/Logging.h" +#include "llvm/Support/LSP/Logging.h" #include "llvm/Support/Chrono.h" #include "llvm/Support/raw_ostream.h" -using namespace mlir; -using namespace mlir::lsp; +using namespace llvm; +using namespace llvm::lsp; -void Logger::setLogLevel(Level logLevel) { get().logLevel = logLevel; } +void Logger::setLogLevel(Level LogLevel) { get().LogLevel = LogLevel; } Logger &Logger::get() { - static Logger logger; - return logger; + static Logger Logger; + return Logger; } -void Logger::log(Level logLevel, const char *fmt, - const llvm::formatv_object_base &message) { - Logger &logger = get(); +void Logger::log(Level LogLevel, const char *Fmt, + const llvm::formatv_object_base &Message) { + Logger &Logger = get(); // Ignore messages with log levels below the current setting in the logger. - if (logLevel < logger.logLevel) + if (LogLevel < Logger.LogLevel) return; // An indicator character for each log level. - const char *logLevelIndicators = "DIE"; + const char *LogLevelIndicators = "DIE"; // Format the message and print to errs. - llvm::sys::TimePoint<> timestamp = std::chrono::system_clock::now(); - std::lock_guard logGuard(logger.mutex); + llvm::sys::TimePoint<> Timestamp = std::chrono::system_clock::now(); + std::lock_guard LogGuard(Logger.Mutex); llvm::errs() << llvm::formatv( "{0}[{1:%H:%M:%S.%L}] {2}\n", - logLevelIndicators[static_cast(logLevel)], timestamp, message); + LogLevelIndicators[static_cast(LogLevel)], Timestamp, Message); llvm::errs().flush(); } diff --git a/llvm/lib/Support/LSP/Protocol.cpp b/llvm/lib/Support/LSP/Protocol.cpp new file mode 100644 index 0000000000000..f22126345a435 --- /dev/null +++ b/llvm/lib/Support/LSP/Protocol.cpp @@ -0,0 +1,1043 @@ +//===--- Protocol.cpp - Language Server Protocol Implementation -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the serialization code for the LSP structs. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/LSP/Protocol.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/JSON.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; +using namespace llvm::lsp; + +// Helper that doesn't treat `null` and absent fields as failures. +template +static bool mapOptOrNull(const llvm::json::Value &Params, + llvm::StringLiteral Prop, T &Out, + llvm::json::Path Path) { + const llvm::json::Object *O = Params.getAsObject(); + assert(O); + + // Field is missing or null. + auto *V = O->get(Prop); + if (!V || V->getAsNull()) + return true; + return fromJSON(*V, Out, Path.field(Prop)); +} + +//===----------------------------------------------------------------------===// +// LSPError +//===----------------------------------------------------------------------===// + +char LSPError::ID; + +//===----------------------------------------------------------------------===// +// URIForFile +//===----------------------------------------------------------------------===// + +static bool isWindowsPath(StringRef Path) { + return Path.size() > 1 && llvm::isAlpha(Path[0]) && Path[1] == ':'; +} + +static bool isNetworkPath(StringRef Path) { + return Path.size() > 2 && Path[0] == Path[1] && + llvm::sys::path::is_separator(Path[0]); +} + +static bool shouldEscapeInURI(unsigned char C) { + // Unreserved characters. + if ((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') || + (C >= '0' && C <= '9')) + return false; + + switch (C) { + case '-': + case '_': + case '.': + case '~': + // '/' is only reserved when parsing. + case '/': + // ':' is only reserved for relative URI paths, which we doesn't produce. + case ':': + return false; + } + return true; +} + +/// Encodes a string according to percent-encoding. +/// - Unreserved characters are not escaped. +/// - Reserved characters always escaped with exceptions like '/'. +/// - All other characters are escaped. +static void percentEncode(StringRef Content, std::string &Out) { + for (unsigned char C : Content) { + if (shouldEscapeInURI(C)) { + Out.push_back('%'); + Out.push_back(llvm::hexdigit(C / 16)); + Out.push_back(llvm::hexdigit(C % 16)); + } else { + Out.push_back(C); + } + } +} + +/// Decodes a string according to percent-encoding. +static std::string percentDecode(StringRef Content) { + std::string Result; + for (auto I = Content.begin(), E = Content.end(); I != E; ++I) { + if (*I != '%') { + Result += *I; + continue; + } + if (*I == '%' && I + 2 < Content.end() && llvm::isHexDigit(*(I + 1)) && + llvm::isHexDigit(*(I + 2))) { + Result.push_back(llvm::hexFromNibbles(*(I + 1), *(I + 2))); + I += 2; + } else { + Result.push_back(*I); + } + } + return Result; +} + +/// Return the set containing the supported URI schemes. +static StringSet<> &getSupportedSchemes() { + static StringSet<> Schemes({"file", "test"}); + return Schemes; +} + +/// Returns true if the given scheme is structurally valid, i.e. it does not +/// contain any invalid scheme characters. This does not check that the scheme +/// is actually supported. +static bool isStructurallyValidScheme(StringRef Scheme) { + if (Scheme.empty()) + return false; + if (!llvm::isAlpha(Scheme[0])) + return false; + return llvm::all_of(llvm::drop_begin(Scheme), [](char C) { + return llvm::isAlnum(C) || C == '+' || C == '.' || C == '-'; + }); +} + +static llvm::Expected uriFromAbsolutePath(StringRef AbsolutePath, + StringRef Scheme) { + std::string Body; + StringRef Authority; + StringRef Root = llvm::sys::path::root_name(AbsolutePath); + if (isNetworkPath(Root)) { + // Windows UNC paths e.g. \\server\share => file://server/share + Authority = Root.drop_front(2); + AbsolutePath.consume_front(Root); + } else if (isWindowsPath(Root)) { + // Windows paths e.g. X:\path => file:///X:/path + Body = "/"; + } + Body += llvm::sys::path::convert_to_slash(AbsolutePath); + + std::string Uri = Scheme.str() + ":"; + if (Authority.empty() && Body.empty()) + return Uri; + + // If authority if empty, we only print body if it starts with "/"; otherwise, + // the URI is invalid. + if (!Authority.empty() || StringRef(Body).starts_with("/")) { + Uri.append("//"); + percentEncode(Authority, Uri); + } + percentEncode(Body, Uri); + return Uri; +} + +static llvm::Expected getAbsolutePath(StringRef Authority, + StringRef Body) { + if (!Body.starts_with("/")) + return llvm::createStringError( + llvm::inconvertibleErrorCode(), + "File scheme: expect body to be an absolute path starting " + "with '/': " + + Body); + SmallString<128> Path; + if (!Authority.empty()) { + // Windows UNC paths e.g. file://server/share => \\server\share + ("//" + Authority).toVector(Path); + } else if (isWindowsPath(Body.substr(1))) { + // Windows paths e.g. file:///X:/path => X:\path + Body.consume_front("/"); + } + Path.append(Body); + llvm::sys::path::native(Path); + return std::string(Path); +} + +static llvm::Expected parseFilePathFromURI(StringRef OrigUri) { + StringRef Uri = OrigUri; + + // Decode the scheme of the URI. + size_t Pos = Uri.find(':'); + if (Pos == StringRef::npos) + return llvm::createStringError(llvm::inconvertibleErrorCode(), + "Scheme must be provided in URI: " + + OrigUri); + StringRef SchemeStr = Uri.substr(0, Pos); + std::string UriScheme = percentDecode(SchemeStr); + if (!isStructurallyValidScheme(UriScheme)) + return llvm::createStringError(llvm::inconvertibleErrorCode(), + "Invalid scheme: " + SchemeStr + + " (decoded: " + UriScheme + ")"); + Uri = Uri.substr(Pos + 1); + + // Decode the authority of the URI. + std::string UriAuthority; + if (Uri.consume_front("//")) { + Pos = Uri.find('/'); + UriAuthority = percentDecode(Uri.substr(0, Pos)); + Uri = Uri.substr(Pos); + } + + // Decode the body of the URI. + std::string UriBody = percentDecode(Uri); + + // Compute the absolute path for this uri. + if (!getSupportedSchemes().contains(UriScheme)) { + return llvm::createStringError(llvm::inconvertibleErrorCode(), + "unsupported URI scheme `" + UriScheme + + "' for workspace files"); + } + return getAbsolutePath(UriAuthority, UriBody); +} + +llvm::Expected URIForFile::fromURI(StringRef Uri) { + llvm::Expected FilePath = parseFilePathFromURI(Uri); + if (!FilePath) + return FilePath.takeError(); + return URIForFile(std::move(*FilePath), Uri.str()); +} + +llvm::Expected URIForFile::fromFile(StringRef AbsoluteFilepath, + StringRef Scheme) { + llvm::Expected Uri = + uriFromAbsolutePath(AbsoluteFilepath, Scheme); + if (!Uri) + return Uri.takeError(); + return fromURI(*Uri); +} + +StringRef URIForFile::scheme() const { return uri().split(':').first; } + +void URIForFile::registerSupportedScheme(StringRef Scheme) { + getSupportedSchemes().insert(Scheme); +} + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, URIForFile &Result, + llvm::json::Path Path) { + if (std::optional Str = Value.getAsString()) { + llvm::Expected ExpectedUri = URIForFile::fromURI(*Str); + if (!ExpectedUri) { + Path.report("unresolvable URI"); + consumeError(ExpectedUri.takeError()); + return false; + } + Result = std::move(*ExpectedUri); + return true; + } + return false; +} + +llvm::json::Value llvm::lsp::toJSON(const URIForFile &Value) { + return Value.uri(); +} + +raw_ostream &llvm::lsp::operator<<(raw_ostream &Os, const URIForFile &Value) { + return Os << Value.uri(); +} + +//===----------------------------------------------------------------------===// +// ClientCapabilities +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + ClientCapabilities &Result, llvm::json::Path Path) { + const llvm::json::Object *O = Value.getAsObject(); + if (!O) { + Path.report("expected object"); + return false; + } + if (const llvm::json::Object *TextDocument = O->getObject("textDocument")) { + if (const llvm::json::Object *DocumentSymbol = + TextDocument->getObject("documentSymbol")) { + if (std::optional HierarchicalSupport = + DocumentSymbol->getBoolean("hierarchicalDocumentSymbolSupport")) + Result.hierarchicalDocumentSymbol = *HierarchicalSupport; + } + if (auto *CodeAction = TextDocument->getObject("codeAction")) { + if (CodeAction->getObject("codeActionLiteralSupport")) + Result.codeActionStructure = true; + } + } + if (auto *Window = O->getObject("window")) { + if (std::optional WorkDoneProgressSupport = + Window->getBoolean("workDoneProgress")) + Result.workDoneProgress = *WorkDoneProgressSupport; + } + return true; +} + +//===----------------------------------------------------------------------===// +// ClientInfo +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, ClientInfo &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + if (!O || !O.map("name", Result.name)) + return false; + + // Don't fail if we can't parse version. + O.map("version", Result.version); + return true; +} + +//===----------------------------------------------------------------------===// +// InitializeParams +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, TraceLevel &Result, + llvm::json::Path Path) { + if (std::optional Str = Value.getAsString()) { + if (*Str == "off") { + Result = TraceLevel::Off; + return true; + } + if (*Str == "messages") { + Result = TraceLevel::Messages; + return true; + } + if (*Str == "verbose") { + Result = TraceLevel::Verbose; + return true; + } + } + return false; +} + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + InitializeParams &Result, llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + if (!O) + return false; + // We deliberately don't fail if we can't parse individual fields. + O.map("capabilities", Result.capabilities); + O.map("trace", Result.trace); + mapOptOrNull(Value, "clientInfo", Result.clientInfo, Path); + + return true; +} + +//===----------------------------------------------------------------------===// +// TextDocumentItem +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + TextDocumentItem &Result, llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("uri", Result.uri) && + O.map("languageId", Result.languageId) && O.map("text", Result.text) && + O.map("version", Result.version); +} + +//===----------------------------------------------------------------------===// +// TextDocumentIdentifier +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(const TextDocumentIdentifier &Value) { + return llvm::json::Object{{"uri", Value.uri}}; +} + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + TextDocumentIdentifier &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("uri", Result.uri); +} + +//===----------------------------------------------------------------------===// +// VersionedTextDocumentIdentifier +//===----------------------------------------------------------------------===// + +llvm::json::Value +llvm::lsp::toJSON(const VersionedTextDocumentIdentifier &Value) { + return llvm::json::Object{ + {"uri", Value.uri}, + {"version", Value.version}, + }; +} + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + VersionedTextDocumentIdentifier &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("uri", Result.uri) && O.map("version", Result.version); +} + +//===----------------------------------------------------------------------===// +// Position +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, Position &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("line", Result.line) && + O.map("character", Result.character); +} + +llvm::json::Value llvm::lsp::toJSON(const Position &Value) { + return llvm::json::Object{ + {"line", Value.line}, + {"character", Value.character}, + }; +} + +raw_ostream &llvm::lsp::operator<<(raw_ostream &Os, const Position &Value) { + return Os << Value.line << ':' << Value.character; +} + +//===----------------------------------------------------------------------===// +// Range +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, Range &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("start", Result.start) && O.map("end", Result.end); +} + +llvm::json::Value llvm::lsp::toJSON(const Range &Value) { + return llvm::json::Object{ + {"start", Value.start}, + {"end", Value.end}, + }; +} + +raw_ostream &llvm::lsp::operator<<(raw_ostream &Os, const Range &Value) { + return Os << Value.start << '-' << Value.end; +} + +//===----------------------------------------------------------------------===// +// Location +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, Location &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("uri", Result.uri) && O.map("range", Result.range); +} + +llvm::json::Value llvm::lsp::toJSON(const Location &Value) { + return llvm::json::Object{ + {"uri", Value.uri}, + {"range", Value.range}, + }; +} + +raw_ostream &llvm::lsp::operator<<(raw_ostream &Os, const Location &Value) { + return Os << Value.range << '@' << Value.uri; +} + +//===----------------------------------------------------------------------===// +// TextDocumentPositionParams +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + TextDocumentPositionParams &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("textDocument", Result.textDocument) && + O.map("position", Result.position); +} + +//===----------------------------------------------------------------------===// +// ReferenceParams +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + ReferenceContext &Result, llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.mapOptional("includeDeclaration", Result.includeDeclaration); +} + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + ReferenceParams &Result, llvm::json::Path Path) { + TextDocumentPositionParams &Base = Result; + llvm::json::ObjectMapper O(Value, Path); + return fromJSON(Value, Base, Path) && O && + O.mapOptional("context", Result.context); +} + +//===----------------------------------------------------------------------===// +// DidOpenTextDocumentParams +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + DidOpenTextDocumentParams &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("textDocument", Result.textDocument); +} + +//===----------------------------------------------------------------------===// +// DidCloseTextDocumentParams +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + DidCloseTextDocumentParams &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("textDocument", Result.textDocument); +} + +//===----------------------------------------------------------------------===// +// DidChangeTextDocumentParams +//===----------------------------------------------------------------------===// + +LogicalResult +TextDocumentContentChangeEvent::applyTo(std::string &Contents) const { + // If there is no range, the full document changed. + if (!range) { + Contents = text; + return success(); + } + + // Try to map the replacement range to the content. + llvm::SourceMgr TmpScrMgr; + TmpScrMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(Contents), + SMLoc()); + SMRange RangeLoc = range->getAsSMRange(TmpScrMgr); + if (!RangeLoc.isValid()) + return failure(); + + Contents.replace(RangeLoc.Start.getPointer() - Contents.data(), + RangeLoc.End.getPointer() - RangeLoc.Start.getPointer(), + text); + return success(); +} + +LogicalResult TextDocumentContentChangeEvent::applyTo( + ArrayRef Changes, std::string &Contents) { + for (const auto &Change : Changes) + if (failed(Change.applyTo(Contents))) + return failure(); + return success(); +} + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + TextDocumentContentChangeEvent &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("range", Result.range) && + O.map("rangeLength", Result.rangeLength) && O.map("text", Result.text); +} + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + DidChangeTextDocumentParams &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("textDocument", Result.textDocument) && + O.map("contentChanges", Result.contentChanges); +} + +//===----------------------------------------------------------------------===// +// MarkupContent +//===----------------------------------------------------------------------===// + +static llvm::StringRef toTextKind(MarkupKind Kind) { + switch (Kind) { + case MarkupKind::PlainText: + return "plaintext"; + case MarkupKind::Markdown: + return "markdown"; + } + llvm_unreachable("Invalid MarkupKind"); +} + +raw_ostream &llvm::lsp::operator<<(raw_ostream &Os, MarkupKind Kind) { + return Os << toTextKind(Kind); +} + +llvm::json::Value llvm::lsp::toJSON(const MarkupContent &Mc) { + if (Mc.value.empty()) + return nullptr; + + return llvm::json::Object{ + {"kind", toTextKind(Mc.kind)}, + {"value", Mc.value}, + }; +} + +//===----------------------------------------------------------------------===// +// Hover +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(const Hover &Hover) { + llvm::json::Object Result{{"contents", toJSON(Hover.contents)}}; + if (Hover.range) + Result["range"] = toJSON(*Hover.range); + return std::move(Result); +} + +//===----------------------------------------------------------------------===// +// DocumentSymbol +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(const DocumentSymbol &Symbol) { + llvm::json::Object Result{{"name", Symbol.name}, + {"kind", static_cast(Symbol.kind)}, + {"range", Symbol.range}, + {"selectionRange", Symbol.selectionRange}}; + + if (!Symbol.detail.empty()) + Result["detail"] = Symbol.detail; + if (!Symbol.children.empty()) + Result["children"] = Symbol.children; + return std::move(Result); +} + +//===----------------------------------------------------------------------===// +// DocumentSymbolParams +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + DocumentSymbolParams &Result, llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("textDocument", Result.textDocument); +} + +//===----------------------------------------------------------------------===// +// DiagnosticRelatedInformation +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + DiagnosticRelatedInformation &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("location", Result.location) && + O.map("message", Result.message); +} + +llvm::json::Value llvm::lsp::toJSON(const DiagnosticRelatedInformation &Info) { + return llvm::json::Object{ + {"location", Info.location}, + {"message", Info.message}, + }; +} + +//===----------------------------------------------------------------------===// +// Diagnostic +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(DiagnosticTag Tag) { + return static_cast(Tag); +} + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, DiagnosticTag &Result, + llvm::json::Path Path) { + if (std::optional I = Value.getAsInteger()) { + Result = (DiagnosticTag)*I; + return true; + } + + return false; +} + +llvm::json::Value llvm::lsp::toJSON(const Diagnostic &Diag) { + llvm::json::Object Result{ + {"range", Diag.range}, + {"severity", (int)Diag.severity}, + {"message", Diag.message}, + }; + if (Diag.category) + Result["category"] = *Diag.category; + if (!Diag.source.empty()) + Result["source"] = Diag.source; + if (Diag.relatedInformation) + Result["relatedInformation"] = *Diag.relatedInformation; + if (!Diag.tags.empty()) + Result["tags"] = Diag.tags; + return std::move(Result); +} + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, Diagnostic &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + if (!O) + return false; + int Severity = 0; + if (!mapOptOrNull(Value, "severity", Severity, Path)) + return false; + Result.severity = (DiagnosticSeverity)Severity; + + return O.map("range", Result.range) && O.map("message", Result.message) && + mapOptOrNull(Value, "category", Result.category, Path) && + mapOptOrNull(Value, "source", Result.source, Path) && + mapOptOrNull(Value, "relatedInformation", Result.relatedInformation, + Path) && + mapOptOrNull(Value, "tags", Result.tags, Path); +} + +//===----------------------------------------------------------------------===// +// PublishDiagnosticsParams +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(const PublishDiagnosticsParams &Params) { + return llvm::json::Object{ + {"uri", Params.uri}, + {"diagnostics", Params.diagnostics}, + {"version", Params.version}, + }; +} + +//===----------------------------------------------------------------------===// +// TextEdit +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, TextEdit &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("range", Result.range) && O.map("newText", Result.newText); +} + +llvm::json::Value llvm::lsp::toJSON(const TextEdit &Value) { + return llvm::json::Object{ + {"range", Value.range}, + {"newText", Value.newText}, + }; +} + +raw_ostream &llvm::lsp::operator<<(raw_ostream &Os, const TextEdit &Value) { + Os << Value.range << " => \""; + llvm::printEscapedString(Value.newText, Os); + return Os << '"'; +} + +//===----------------------------------------------------------------------===// +// CompletionItemKind +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + CompletionItemKind &Result, llvm::json::Path Path) { + if (std::optional IntValue = Value.getAsInteger()) { + if (*IntValue < static_cast(CompletionItemKind::Text) || + *IntValue > static_cast(CompletionItemKind::TypeParameter)) + return false; + Result = static_cast(*IntValue); + return true; + } + return false; +} + +CompletionItemKind llvm::lsp::adjustKindToCapability( + CompletionItemKind Kind, + CompletionItemKindBitset &SupportedCompletionItemKinds) { + size_t KindVal = static_cast(Kind); + if (KindVal >= kCompletionItemKindMin && + KindVal <= SupportedCompletionItemKinds.size() && + SupportedCompletionItemKinds[KindVal]) + return Kind; + + // Provide some fall backs for common kinds that are close enough. + switch (Kind) { + case CompletionItemKind::Folder: + return CompletionItemKind::File; + case CompletionItemKind::EnumMember: + return CompletionItemKind::Enum; + case CompletionItemKind::Struct: + return CompletionItemKind::Class; + default: + return CompletionItemKind::Text; + } +} + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + CompletionItemKindBitset &Result, + llvm::json::Path Path) { + if (const llvm::json::Array *ArrayValue = Value.getAsArray()) { + for (size_t I = 0, E = ArrayValue->size(); I < E; ++I) { + CompletionItemKind KindOut; + if (fromJSON((*ArrayValue)[I], KindOut, Path.index(I))) + Result.set(size_t(KindOut)); + } + return true; + } + return false; +} + +//===----------------------------------------------------------------------===// +// CompletionItem +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(const CompletionItem &Value) { + assert(!Value.label.empty() && "completion item label is required"); + llvm::json::Object Result{{"label", Value.label}}; + if (Value.kind != CompletionItemKind::Missing) + Result["kind"] = static_cast(Value.kind); + if (!Value.detail.empty()) + Result["detail"] = Value.detail; + if (Value.documentation) + Result["documentation"] = Value.documentation; + if (!Value.sortText.empty()) + Result["sortText"] = Value.sortText; + if (!Value.filterText.empty()) + Result["filterText"] = Value.filterText; + if (!Value.insertText.empty()) + Result["insertText"] = Value.insertText; + if (Value.insertTextFormat != InsertTextFormat::Missing) + Result["insertTextFormat"] = static_cast(Value.insertTextFormat); + if (Value.textEdit) + Result["textEdit"] = *Value.textEdit; + if (!Value.additionalTextEdits.empty()) { + Result["additionalTextEdits"] = + llvm::json::Array(Value.additionalTextEdits); + } + if (Value.deprecated) + Result["deprecated"] = Value.deprecated; + return std::move(Result); +} + +raw_ostream &llvm::lsp::operator<<(raw_ostream &Os, + const CompletionItem &Value) { + return Os << Value.label << " - " << toJSON(Value); +} + +bool llvm::lsp::operator<(const CompletionItem &Lhs, + const CompletionItem &Rhs) { + return (Lhs.sortText.empty() ? Lhs.label : Lhs.sortText) < + (Rhs.sortText.empty() ? Rhs.label : Rhs.sortText); +} + +//===----------------------------------------------------------------------===// +// CompletionList +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(const CompletionList &Value) { + return llvm::json::Object{ + {"isIncomplete", Value.isIncomplete}, + {"items", llvm::json::Array(Value.items)}, + }; +} + +//===----------------------------------------------------------------------===// +// CompletionContext +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + CompletionContext &Result, llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + int TriggerKind; + if (!O || !O.map("triggerKind", TriggerKind) || + !mapOptOrNull(Value, "triggerCharacter", Result.triggerCharacter, Path)) + return false; + Result.triggerKind = static_cast(TriggerKind); + return true; +} + +//===----------------------------------------------------------------------===// +// CompletionParams +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + CompletionParams &Result, llvm::json::Path Path) { + if (!fromJSON(Value, static_cast(Result), Path)) + return false; + if (const llvm::json::Value *Context = Value.getAsObject()->get("context")) + return fromJSON(*Context, Result.context, Path.field("context")); + return true; +} + +//===----------------------------------------------------------------------===// +// ParameterInformation +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(const ParameterInformation &Value) { + assert((Value.labelOffsets || !Value.labelString.empty()) && + "parameter information label is required"); + llvm::json::Object Result; + if (Value.labelOffsets) + Result["label"] = llvm::json::Array( + {Value.labelOffsets->first, Value.labelOffsets->second}); + else + Result["label"] = Value.labelString; + if (!Value.documentation.empty()) + Result["documentation"] = Value.documentation; + return std::move(Result); +} + +//===----------------------------------------------------------------------===// +// SignatureInformation +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(const SignatureInformation &Value) { + assert(!Value.label.empty() && "signature information label is required"); + llvm::json::Object Result{ + {"label", Value.label}, + {"parameters", llvm::json::Array(Value.parameters)}, + }; + if (!Value.documentation.empty()) + Result["documentation"] = Value.documentation; + return std::move(Result); +} + +raw_ostream &llvm::lsp::operator<<(raw_ostream &Os, + const SignatureInformation &Value) { + return Os << Value.label << " - " << toJSON(Value); +} + +//===----------------------------------------------------------------------===// +// SignatureHelp +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(const SignatureHelp &Value) { + assert(Value.activeSignature >= 0 && + "Unexpected negative value for number of active signatures."); + assert(Value.activeParameter >= 0 && + "Unexpected negative value for active parameter index"); + return llvm::json::Object{ + {"activeSignature", Value.activeSignature}, + {"activeParameter", Value.activeParameter}, + {"signatures", llvm::json::Array(Value.signatures)}, + }; +} + +//===----------------------------------------------------------------------===// +// DocumentLinkParams +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + DocumentLinkParams &Result, llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("textDocument", Result.textDocument); +} + +//===----------------------------------------------------------------------===// +// DocumentLink +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(const DocumentLink &Value) { + return llvm::json::Object{ + {"range", Value.range}, + {"target", Value.target}, + }; +} + +//===----------------------------------------------------------------------===// +// InlayHintsParams +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + InlayHintsParams &Result, llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("textDocument", Result.textDocument) && + O.map("range", Result.range); +} + +//===----------------------------------------------------------------------===// +// InlayHint +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(const InlayHint &Value) { + return llvm::json::Object{{"position", Value.position}, + {"kind", (int)Value.kind}, + {"label", Value.label}, + {"paddingLeft", Value.paddingLeft}, + {"paddingRight", Value.paddingRight}}; +} +bool llvm::lsp::operator==(const InlayHint &Lhs, const InlayHint &Rhs) { + return std::tie(Lhs.position, Lhs.kind, Lhs.label) == + std::tie(Rhs.position, Rhs.kind, Rhs.label); +} +bool llvm::lsp::operator<(const InlayHint &Lhs, const InlayHint &Rhs) { + return std::tie(Lhs.position, Lhs.kind, Lhs.label) < + std::tie(Rhs.position, Rhs.kind, Rhs.label); +} + +llvm::raw_ostream &llvm::lsp::operator<<(llvm::raw_ostream &Os, + InlayHintKind Value) { + switch (Value) { + case InlayHintKind::Parameter: + return Os << "parameter"; + case InlayHintKind::Type: + return Os << "type"; + } + llvm_unreachable("Unknown InlayHintKind"); +} + +//===----------------------------------------------------------------------===// +// CodeActionContext +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + CodeActionContext &Result, llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + if (!O || !O.map("diagnostics", Result.diagnostics)) + return false; + O.map("only", Result.only); + return true; +} + +//===----------------------------------------------------------------------===// +// CodeActionParams +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + CodeActionParams &Result, llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("textDocument", Result.textDocument) && + O.map("range", Result.range) && O.map("context", Result.context); +} + +//===----------------------------------------------------------------------===// +// WorkspaceEdit +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, WorkspaceEdit &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("changes", Result.changes); +} + +llvm::json::Value llvm::lsp::toJSON(const WorkspaceEdit &Value) { + llvm::json::Object FileChanges; + for (auto &Change : Value.changes) + FileChanges[Change.first] = llvm::json::Array(Change.second); + return llvm::json::Object{{"changes", std::move(FileChanges)}}; +} + +//===----------------------------------------------------------------------===// +// CodeAction +//===----------------------------------------------------------------------===// + +const llvm::StringLiteral CodeAction::kQuickFix = "quickfix"; +const llvm::StringLiteral CodeAction::kRefactor = "refactor"; +const llvm::StringLiteral CodeAction::kInfo = "info"; + +llvm::json::Value llvm::lsp::toJSON(const CodeAction &Value) { + llvm::json::Object CodeAction{{"title", Value.title}}; + if (Value.kind) + CodeAction["kind"] = *Value.kind; + if (Value.diagnostics) + CodeAction["diagnostics"] = llvm::json::Array(*Value.diagnostics); + if (Value.isPreferred) + CodeAction["isPreferred"] = true; + if (Value.edit) + CodeAction["edit"] = *Value.edit; + return std::move(CodeAction); +} diff --git a/llvm/lib/Support/LSP/Transport.cpp b/llvm/lib/Support/LSP/Transport.cpp new file mode 100644 index 0000000000000..e71f17701636b --- /dev/null +++ b/llvm/lib/Support/LSP/Transport.cpp @@ -0,0 +1,369 @@ +//===--- JSONTransport.cpp - sending and receiving LSP messages over JSON -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/LSP/Transport.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Protocol.h" +#include +#include +#include +#include + +using namespace llvm; +using namespace llvm::lsp; + +//===----------------------------------------------------------------------===// +// Reply +//===----------------------------------------------------------------------===// + +namespace { +/// Function object to reply to an LSP call. +/// Each instance must be called exactly once, otherwise: +/// - if there was no reply, an error reply is sent +/// - if there were multiple replies, only the first is sent +class Reply { +public: + Reply(const llvm::json::Value &Id, StringRef Method, JSONTransport &Transport, + std::mutex &TransportOutputMutex); + Reply(Reply &&Other); + Reply &operator=(Reply &&) = delete; + Reply(const Reply &) = delete; + Reply &operator=(const Reply &) = delete; + + void operator()(llvm::Expected Reply); + +private: + std::string Method; + std::atomic Replied = {false}; + llvm::json::Value Id; + JSONTransport *Transport; + std::mutex &TransportOutputMutex; +}; +} // namespace + +Reply::Reply(const llvm::json::Value &Id, llvm::StringRef Method, + JSONTransport &Transport, std::mutex &TransportOutputMutex) + : Method(Method), Id(Id), Transport(&Transport), + TransportOutputMutex(TransportOutputMutex) {} + +Reply::Reply(Reply &&Other) + : Method(Other.Method), Replied(Other.Replied.load()), + Id(std::move(Other.Id)), Transport(Other.Transport), + TransportOutputMutex(Other.TransportOutputMutex) { + Other.Transport = nullptr; +} + +void Reply::operator()(llvm::Expected Reply) { + if (Replied.exchange(true)) { + Logger::error("Replied twice to message {0}({1})", Method, Id); + assert(false && "must reply to each call only once!"); + return; + } + assert(Transport && "expected valid transport to reply to"); + + std::lock_guard TransportLock(TransportOutputMutex); + if (Reply) { + Logger::info("--> reply:{0}({1})", Method, Id); + Transport->reply(std::move(Id), std::move(Reply)); + } else { + llvm::Error Error = Reply.takeError(); + Logger::info("--> reply:{0}({1}): {2}", Method, Id, Error); + Transport->reply(std::move(Id), std::move(Error)); + } +} + +//===----------------------------------------------------------------------===// +// MessageHandler +//===----------------------------------------------------------------------===// + +bool MessageHandler::onNotify(llvm::StringRef Method, llvm::json::Value Value) { + Logger::info("--> {0}", Method); + + if (Method == "exit") + return false; + if (Method == "$cancel") { + // TODO: Add support for cancelling requests. + } else { + auto It = NotificationHandlers.find(Method); + if (It != NotificationHandlers.end()) + It->second(std::move(Value)); + } + return true; +} + +bool MessageHandler::onCall(llvm::StringRef Method, llvm::json::Value Params, + llvm::json::Value Id) { + Logger::info("--> {0}({1})", Method, Id); + + Reply Reply(Id, Method, Transport, TransportOutputMutex); + + auto It = MethodHandlers.find(Method); + if (It != MethodHandlers.end()) { + It->second(std::move(Params), std::move(Reply)); + } else { + Reply(llvm::make_error("method not found: " + Method.str(), + ErrorCode::MethodNotFound)); + } + return true; +} + +bool MessageHandler::onReply(llvm::json::Value Id, + llvm::Expected Result) { + // Find the response handler in the mapping. If it exists, move it out of the + // mapping and erase it. + ResponseHandlerTy ResponseHandler; + { + std::lock_guard responseHandlersLock(ResponseHandlerTy); + auto It = ResponseHandlers.find(debugString(Id)); + if (It != ResponseHandlers.end()) { + ResponseHandler = std::move(It->second); + ResponseHandlers.erase(It); + } + } + + // If we found a response handler, invoke it. Otherwise, log an error. + if (ResponseHandler.second) { + Logger::info("--> reply:{0}({1})", ResponseHandler.first, Id); + ResponseHandler.second(std::move(Id), std::move(Result)); + } else { + Logger::error( + "received a reply with ID {0}, but there was no such outgoing request", + Id); + if (!Result) + llvm::consumeError(Result.takeError()); + } + return true; +} + +//===----------------------------------------------------------------------===// +// JSONTransport +//===----------------------------------------------------------------------===// + +/// Encode the given error as a JSON object. +static llvm::json::Object encodeError(llvm::Error Error) { + std::string Message; + ErrorCode Code = ErrorCode::UnknownErrorCode; + auto HandlerFn = [&](const LSPError &LspError) -> llvm::Error { + Message = LspError.message; + Code = LspError.code; + return llvm::Error::success(); + }; + if (llvm::Error Unhandled = llvm::handleErrors(std::move(Error), HandlerFn)) + Message = llvm::toString(std::move(Unhandled)); + + return llvm::json::Object{ + {"message", std::move(Message)}, + {"code", int64_t(Code)}, + }; +} + +/// Decode the given JSON object into an error. +llvm::Error decodeError(const llvm::json::Object &O) { + StringRef Msg = O.getString("message").value_or("Unspecified error"); + if (std::optional Code = O.getInteger("code")) + return llvm::make_error(Msg.str(), ErrorCode(*Code)); + return llvm::make_error(llvm::inconvertibleErrorCode(), + Msg.str()); +} + +void JSONTransport::notify(StringRef Method, llvm::json::Value Params) { + sendMessage(llvm::json::Object{ + {"jsonrpc", "2.0"}, + {"method", Method}, + {"params", std::move(Params)}, + }); +} +void JSONTransport::call(StringRef Method, llvm::json::Value Params, + llvm::json::Value Id) { + sendMessage(llvm::json::Object{ + {"jsonrpc", "2.0"}, + {"id", std::move(Id)}, + {"method", Method}, + {"params", std::move(Params)}, + }); +} +void JSONTransport::reply(llvm::json::Value Id, + llvm::Expected Result) { + if (Result) { + return sendMessage(llvm::json::Object{ + {"jsonrpc", "2.0"}, + {"id", std::move(Id)}, + {"result", std::move(*Result)}, + }); + } + + sendMessage(llvm::json::Object{ + {"jsonrpc", "2.0"}, + {"id", std::move(Id)}, + {"error", encodeError(Result.takeError())}, + }); +} + +llvm::Error JSONTransport::run(MessageHandler &Handler) { + std::string Json; + while (!In->isEndOfInput()) { + if (In->hasError()) { + return llvm::errorCodeToError( + std::error_code(errno, std::system_category())); + } + + if (succeeded(In->readMessage(Json))) { + if (llvm::Expected Doc = llvm::json::parse(Json)) { + if (!handleMessage(std::move(*Doc), Handler)) + return llvm::Error::success(); + } else { + Logger::error("JSON parse error: {0}", llvm::toString(Doc.takeError())); + } + } + } + return llvm::errorCodeToError(std::make_error_code(std::errc::io_error)); +} + +void JSONTransport::sendMessage(llvm::json::Value Msg) { + OutputBuffer.clear(); + llvm::raw_svector_ostream os(OutputBuffer); + os << llvm::formatv(PrettyOutput ? "{0:2}\n" : "{0}", Msg); + Out << "Content-Length: " << OutputBuffer.size() << "\r\n\r\n" + << OutputBuffer; + Out.flush(); + Logger::debug(">>> {0}\n", OutputBuffer); +} + +bool JSONTransport::handleMessage(llvm::json::Value Msg, + MessageHandler &Handler) { + // Message must be an object with "jsonrpc":"2.0". + llvm::json::Object *Object = Msg.getAsObject(); + if (!Object || + Object->getString("jsonrpc") != std::optional("2.0")) + return false; + + // `id` may be any JSON value. If absent, this is a notification. + std::optional Id; + if (llvm::json::Value *I = Object->get("id")) + Id = std::move(*I); + std::optional Method = Object->getString("method"); + + // This is a response. + if (!Method) { + if (!Id) + return false; + if (auto *Err = Object->getObject("error")) + return Handler.onReply(std::move(*Id), decodeError(*Err)); + // result should be given, use null if not. + llvm::json::Value Result = nullptr; + if (llvm::json::Value *R = Object->get("result")) + Result = std::move(*R); + return Handler.onReply(std::move(*Id), std::move(Result)); + } + + // Params should be given, use null if not. + llvm::json::Value Params = nullptr; + if (llvm::json::Value *P = Object->get("params")) + Params = std::move(*P); + + if (Id) + return Handler.onCall(*Method, std::move(Params), std::move(*Id)); + return Handler.onNotify(*Method, std::move(Params)); +} + +/// Tries to read a line up to and including \n. +/// If failing, feof(), ferror(), or shutdownRequested() will be set. +LogicalResult readLine(std::FILE *In, SmallVectorImpl &Out) { + // Big enough to hold any reasonable header line. May not fit content lines + // in delimited mode, but performance doesn't matter for that mode. + static constexpr int BufSize = 128; + size_t Size = 0; + Out.clear(); + for (;;) { + Out.resize_for_overwrite(Size + BufSize); + if (!std::fgets(&Out[Size], BufSize, In)) + return failure(); + + clearerr(In); + + // If the line contained null bytes, anything after it (including \n) will + // be ignored. Fortunately this is not a legal header or JSON. + size_t Read = std::strlen(&Out[Size]); + if (Read > 0 && Out[Size + Read - 1] == '\n') { + Out.resize(Size + Read); + return success(); + } + Size += Read; + } +} + +// Returns std::nullopt when: +// - ferror(), feof(), or shutdownRequested() are set. +// - Content-Length is missing or empty (protocol error) +LogicalResult +JSONTransportInputOverFile::readStandardMessage(std::string &Json) { + // A Language Server Protocol message starts with a set of HTTP headers, + // delimited by \r\n, and terminated by an empty line (\r\n). + unsigned long long ContentLength = 0; + llvm::SmallString<128> Line; + while (true) { + if (feof(In) || hasError() || failed(readLine(In, Line))) + return failure(); + + // Content-Length is a mandatory header, and the only one we handle. + StringRef LineRef = Line; + if (LineRef.consume_front("Content-Length: ")) { + llvm::getAsUnsignedInteger(LineRef.trim(), 0, ContentLength); + } else if (!LineRef.trim().empty()) { + // It's another header, ignore it. + continue; + } else { + // An empty line indicates the end of headers. Go ahead and read the JSON. + break; + } + } + + // The fuzzer likes crashing us by sending "Content-Length: 9999999999999999" + if (ContentLength == 0 || ContentLength > 1 << 30) + return failure(); + + Json.resize(ContentLength); + for (size_t Pos = 0, Read; Pos < ContentLength; Pos += Read) { + Read = std::fread(&Json[Pos], 1, ContentLength - Pos, In); + if (Read == 0) + return failure(); + + // If we're done, the error was transient. If we're not done, either it was + // transient or we'll see it again on retry. + clearerr(In); + Pos += Read; + } + return success(); +} + +/// For lit tests we support a simplified syntax: +/// - messages are delimited by '// -----' on a line by itself +/// - lines starting with // are ignored. +/// This is a testing path, so favor simplicity over performance here. +/// When returning failure: feof(), ferror(), or shutdownRequested() will be +/// set. +LogicalResult +JSONTransportInputOverFile::readDelimitedMessage(std::string &Json) { + Json.clear(); + llvm::SmallString<128> Line; + while (succeeded(readLine(In, Line))) { + StringRef LineRef = Line.str().trim(); + if (LineRef.starts_with("//")) { + // Found a delimiter for the message. + if (LineRef == "// -----") + break; + continue; + } + + Json += Line; + } + + return failure(ferror(In)); +} diff --git a/llvm/unittests/Support/CMakeLists.txt b/llvm/unittests/Support/CMakeLists.txt index 0910a0b296dd0..d1dfb1dc4a722 100644 --- a/llvm/unittests/Support/CMakeLists.txt +++ b/llvm/unittests/Support/CMakeLists.txt @@ -125,6 +125,8 @@ add_llvm_unittest(SupportTests intrinsics_gen ) +add_subdirectory(LSP) + target_link_libraries(SupportTests PRIVATE LLVMTestingSupport) # Disable all warning for AlignOfTest.cpp, diff --git a/llvm/unittests/Support/LSP/CMakeLists.txt b/llvm/unittests/Support/LSP/CMakeLists.txt new file mode 100644 index 0000000000000..790a8b725469b --- /dev/null +++ b/llvm/unittests/Support/LSP/CMakeLists.txt @@ -0,0 +1,8 @@ +set(LLVM_LINK_COMPONENTS + SupportLSP +) + +add_llvm_unittest(LLVMSupportLSPTests + Protocol.cpp + Transport.cpp +) diff --git a/mlir/unittests/Tools/lsp-server-support/Protocol.cpp b/llvm/unittests/Support/LSP/Protocol.cpp similarity index 93% rename from mlir/unittests/Tools/lsp-server-support/Protocol.cpp rename to llvm/unittests/Support/LSP/Protocol.cpp index 04d7b2fbb440f..43c548c24b38b 100644 --- a/mlir/unittests/Tools/lsp-server-support/Protocol.cpp +++ b/llvm/unittests/Support/LSP/Protocol.cpp @@ -6,12 +6,12 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Tools/lsp-server-support/Protocol.h" +#include "llvm/Support/LSP/Protocol.h" #include "gtest/gtest.h" -using namespace mlir; -using namespace mlir::lsp; +using namespace llvm; +using namespace llvm::lsp; using namespace testing; namespace { diff --git a/mlir/unittests/Tools/lsp-server-support/Transport.cpp b/llvm/unittests/Support/LSP/Transport.cpp similarity index 96% rename from mlir/unittests/Tools/lsp-server-support/Transport.cpp rename to llvm/unittests/Support/LSP/Transport.cpp index 92581bd2bad08..514e93e983523 100644 --- a/mlir/unittests/Tools/lsp-server-support/Transport.cpp +++ b/llvm/unittests/Support/LSP/Transport.cpp @@ -6,15 +6,15 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Tools/lsp-server-support/Transport.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Protocol.h" +#include "llvm/Support/LSP/Transport.h" #include "llvm/Support/FileSystem.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Protocol.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -using namespace mlir; -using namespace mlir::lsp; +using namespace llvm; +using namespace llvm::lsp; using namespace testing; namespace { @@ -88,7 +88,7 @@ class TransportInputTest : public Test { TEST_F(TransportInputTest, RequestWithInvalidParams) { struct Handler { void onMethod(const TextDocumentItem ¶ms, - mlir::lsp::Callback callback) {} + llvm::lsp::Callback callback) {} } handler; getMessageHandler().method("invalid-params-request", &handler, &Handler::onMethod); diff --git a/mlir/include/mlir/Tools/lsp-server-support/SourceMgrUtils.h b/mlir/include/mlir/Tools/lsp-server-support/SourceMgrUtils.h index 9ed8326a602e6..920ce831e42b6 100644 --- a/mlir/include/mlir/Tools/lsp-server-support/SourceMgrUtils.h +++ b/mlir/include/mlir/Tools/lsp-server-support/SourceMgrUtils.h @@ -14,7 +14,8 @@ #ifndef MLIR_TOOLS_LSPSERVERSUPPORT_SOURCEMGRUTILS_H #define MLIR_TOOLS_LSPSERVERSUPPORT_SOURCEMGRUTILS_H -#include "mlir/Tools/lsp-server-support/Protocol.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/LSP/Protocol.h" #include "llvm/Support/SourceMgr.h" #include @@ -45,17 +46,18 @@ bool contains(SMRange range, SMLoc loc); /// This class represents a single include within a root file. struct SourceMgrInclude { - SourceMgrInclude(const lsp::URIForFile &uri, const lsp::Range &range) + SourceMgrInclude(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Range &range) : uri(uri), range(range) {} /// Build a hover for the current include file. - Hover buildHover() const; + llvm::lsp::Hover buildHover() const; /// The URI of the file that is included. - lsp::URIForFile uri; + llvm::lsp::URIForFile uri; /// The range of the include directive. - lsp::Range range; + llvm::lsp::Range range; }; /// Given a source manager, gather all of the processed include files. These are diff --git a/mlir/include/mlir/Tools/lsp-server-support/Transport.h b/mlir/include/mlir/Tools/lsp-server-support/Transport.h deleted file mode 100644 index 0010a475fedd2..0000000000000 --- a/mlir/include/mlir/Tools/lsp-server-support/Transport.h +++ /dev/null @@ -1,283 +0,0 @@ -//===--- Transport.h - Sending and Receiving LSP messages -------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// The language server protocol is usually implemented by writing messages as -// JSON-RPC over the stdin/stdout of a subprocess. This file contains a JSON -// transport interface that handles this communication. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H -#define MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H - -#include "mlir/Support/DebugStringHelper.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Protocol.h" -#include "llvm/ADT/FunctionExtras.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/FormatAdapters.h" -#include "llvm/Support/JSON.h" -#include "llvm/Support/raw_ostream.h" -#include - -namespace mlir { -namespace lsp { -class MessageHandler; - -//===----------------------------------------------------------------------===// -// JSONTransport -//===----------------------------------------------------------------------===// - -/// The encoding style of the JSON-RPC messages (both input and output). -enum JSONStreamStyle { - /// Encoding per the LSP specification, with mandatory Content-Length header. - Standard, - /// Messages are delimited by a '// -----' line. Comment lines start with //. - Delimited -}; - -/// An abstract class used by the JSONTransport to read JSON message. -class JSONTransportInput { -public: - explicit JSONTransportInput(JSONStreamStyle style = JSONStreamStyle::Standard) - : style(style) {} - virtual ~JSONTransportInput() = default; - - virtual bool hasError() const = 0; - virtual bool isEndOfInput() const = 0; - - /// Read in a message from the input stream. - LogicalResult readMessage(std::string &json) { - return style == JSONStreamStyle::Delimited ? readDelimitedMessage(json) - : readStandardMessage(json); - } - virtual LogicalResult readDelimitedMessage(std::string &json) = 0; - virtual LogicalResult readStandardMessage(std::string &json) = 0; - -private: - /// The JSON stream style to use. - JSONStreamStyle style; -}; - -/// Concrete implementation of the JSONTransportInput that reads from a file. -class JSONTransportInputOverFile : public JSONTransportInput { -public: - explicit JSONTransportInputOverFile( - std::FILE *in, JSONStreamStyle style = JSONStreamStyle::Standard) - : JSONTransportInput(style), in(in) {} - - bool hasError() const final { return ferror(in); } - bool isEndOfInput() const final { return feof(in); } - - LogicalResult readDelimitedMessage(std::string &json) final; - LogicalResult readStandardMessage(std::string &json) final; - -private: - std::FILE *in; -}; - -/// A transport class that performs the JSON-RPC communication with the LSP -/// client. -class JSONTransport { -public: - JSONTransport(std::unique_ptr in, raw_ostream &out, - bool prettyOutput = false) - : in(std::move(in)), out(out), prettyOutput(prettyOutput) {} - - JSONTransport(std::FILE *in, raw_ostream &out, - JSONStreamStyle style = JSONStreamStyle::Standard, - bool prettyOutput = false) - : in(std::make_unique(in, style)), out(out), - prettyOutput(prettyOutput) {} - - /// The following methods are used to send a message to the LSP client. - void notify(StringRef method, llvm::json::Value params); - void call(StringRef method, llvm::json::Value params, llvm::json::Value id); - void reply(llvm::json::Value id, llvm::Expected result); - - /// Start executing the JSON-RPC transport. - llvm::Error run(MessageHandler &handler); - -private: - /// Dispatches the given incoming json message to the message handler. - bool handleMessage(llvm::json::Value msg, MessageHandler &handler); - /// Writes the given message to the output stream. - void sendMessage(llvm::json::Value msg); - -private: - /// The input to read a message from. - std::unique_ptr in; - SmallVector outputBuffer; - /// The output file stream. - raw_ostream &out; - /// If the output JSON should be formatted for easier readability. - bool prettyOutput; -}; - -//===----------------------------------------------------------------------===// -// MessageHandler -//===----------------------------------------------------------------------===// - -/// A Callback is a void function that accepts Expected. This is -/// accepted by functions that logically return T. -template -using Callback = llvm::unique_function)>; - -/// An OutgoingNotification is a function used for outgoing notifications -/// send to the client. -template -using OutgoingNotification = llvm::unique_function; - -/// An OutgoingRequest is a function used for outgoing requests to send to -/// the client. -template -using OutgoingRequest = - llvm::unique_function; - -/// An `OutgoingRequestCallback` is invoked when an outgoing request to the -/// client receives a response in turn. It is passed the original request's ID, -/// as well as the response result. -template -using OutgoingRequestCallback = - std::function)>; - -/// A handler used to process the incoming transport messages. -class MessageHandler { -public: - MessageHandler(JSONTransport &transport) : transport(transport) {} - - bool onNotify(StringRef method, llvm::json::Value value); - bool onCall(StringRef method, llvm::json::Value params, llvm::json::Value id); - bool onReply(llvm::json::Value id, llvm::Expected result); - - template - static llvm::Expected parse(const llvm::json::Value &raw, - StringRef payloadName, StringRef payloadKind) { - T result; - llvm::json::Path::Root root; - if (fromJSON(raw, result, root)) - return std::move(result); - - // Dump the relevant parts of the broken message. - std::string context; - llvm::raw_string_ostream os(context); - root.printErrorContext(raw, os); - - // Report the error (e.g. to the client). - return llvm::make_error( - llvm::formatv("failed to decode {0} {1}: {2}", payloadName, payloadKind, - fmt_consume(root.getError())), - ErrorCode::InvalidParams); - } - - template - void method(llvm::StringLiteral method, ThisT *thisPtr, - void (ThisT::*handler)(const Param &, Callback)) { - methodHandlers[method] = [method, handler, - thisPtr](llvm::json::Value rawParams, - Callback reply) { - llvm::Expected param = parse(rawParams, method, "request"); - if (!param) - return reply(param.takeError()); - (thisPtr->*handler)(*param, std::move(reply)); - }; - } - - template - void notification(llvm::StringLiteral method, ThisT *thisPtr, - void (ThisT::*handler)(const Param &)) { - notificationHandlers[method] = [method, handler, - thisPtr](llvm::json::Value rawParams) { - llvm::Expected param = - parse(rawParams, method, "notification"); - if (!param) { - return llvm::consumeError( - llvm::handleErrors(param.takeError(), [](const LSPError &lspError) { - Logger::error("JSON parsing error: {0}", - lspError.message.c_str()); - })); - } - (thisPtr->*handler)(*param); - }; - } - - /// Create an OutgoingNotification object used for the given method. - template - OutgoingNotification outgoingNotification(llvm::StringLiteral method) { - return [&, method](const T ¶ms) { - std::lock_guard transportLock(transportOutputMutex); - Logger::info("--> {0}", method); - transport.notify(method, llvm::json::Value(params)); - }; - } - - /// Create an OutgoingRequest function that, when called, sends a request with - /// the given method via the transport. Should the outgoing request be - /// met with a response, the result JSON is parsed and the response callback - /// is invoked. - template - OutgoingRequest - outgoingRequest(llvm::StringLiteral method, - OutgoingRequestCallback callback) { - return [&, method, callback](const Param ¶m, llvm::json::Value id) { - auto callbackWrapper = [method, callback = std::move(callback)]( - llvm::json::Value id, - llvm::Expected value) { - if (!value) - return callback(std::move(id), value.takeError()); - - std::string responseName = llvm::formatv("reply:{0}({1})", method, id); - llvm::Expected result = - parse(*value, responseName, "response"); - if (!result) - return callback(std::move(id), result.takeError()); - - return callback(std::move(id), *result); - }; - - { - std::lock_guard lock(responseHandlersMutex); - responseHandlers.insert( - {debugString(id), std::make_pair(method.str(), callbackWrapper)}); - } - - std::lock_guard transportLock(transportOutputMutex); - Logger::info("--> {0}({1})", method, id); - transport.call(method, llvm::json::Value(param), id); - }; - } - -private: - template - using HandlerMap = llvm::StringMap>; - - HandlerMap notificationHandlers; - HandlerMap)> - methodHandlers; - - /// A pair of (1) the original request's method name, and (2) the callback - /// function to be invoked for responses. - using ResponseHandlerTy = - std::pair>; - /// A mapping from request/response ID to response handler. - llvm::StringMap responseHandlers; - /// Mutex to guard insertion into the response handler map. - std::mutex responseHandlersMutex; - - JSONTransport &transport; - - /// Mutex to guard sending output messages to the transport. - std::mutex transportOutputMutex; -}; - -} // namespace lsp -} // namespace mlir - -#endif diff --git a/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h b/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h index 4811ecb5e92b7..0d9ba2a0d1607 100644 --- a/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h +++ b/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h @@ -16,14 +16,16 @@ namespace llvm { template class function_ref; +namespace lsp { +class URIForFile; +} // namespace lsp } // namespace llvm namespace mlir { class DialectRegistry; namespace lsp { -class URIForFile; using DialectRegistryFn = - llvm::function_ref; + llvm::function_ref; } // namespace lsp } // namespace mlir diff --git a/mlir/lib/Tools/lsp-server-support/CMakeLists.txt b/mlir/lib/Tools/lsp-server-support/CMakeLists.txt index 48a96016b792f..2fe29f1b9ec41 100644 --- a/mlir/lib/Tools/lsp-server-support/CMakeLists.txt +++ b/mlir/lib/Tools/lsp-server-support/CMakeLists.txt @@ -1,13 +1,13 @@ add_mlir_library(MLIRLspServerSupportLib CompilationDatabase.cpp - Logging.cpp - Protocol.cpp SourceMgrUtils.cpp - Transport.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/lsp-server-support + LINK_COMPONENTS + SupportLSP + LINK_LIBS PUBLIC MLIRSupport - ) +) diff --git a/mlir/lib/Tools/lsp-server-support/CompilationDatabase.cpp b/mlir/lib/Tools/lsp-server-support/CompilationDatabase.cpp index 9ae0674383a1d..67b8ef6a256bb 100644 --- a/mlir/lib/Tools/lsp-server-support/CompilationDatabase.cpp +++ b/mlir/lib/Tools/lsp-server-support/CompilationDatabase.cpp @@ -8,14 +8,15 @@ #include "mlir/Tools/lsp-server-support/CompilationDatabase.h" #include "mlir/Support/FileUtilities.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Protocol.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Protocol.h" #include "llvm/Support/YAMLTraits.h" using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::Logger; //===----------------------------------------------------------------------===// // YamlFileInfo diff --git a/mlir/lib/Tools/lsp-server-support/Protocol.cpp b/mlir/lib/Tools/lsp-server-support/Protocol.cpp deleted file mode 100644 index 98287048355c1..0000000000000 --- a/mlir/lib/Tools/lsp-server-support/Protocol.cpp +++ /dev/null @@ -1,1043 +0,0 @@ -//===--- Protocol.cpp - Language Server Protocol Implementation -----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file contains the serialization code for the LSP structs. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Tools/lsp-server-support/Protocol.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/StringSet.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/JSON.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/Path.h" -#include "llvm/Support/raw_ostream.h" - -using namespace mlir; -using namespace mlir::lsp; - -// Helper that doesn't treat `null` and absent fields as failures. -template -static bool mapOptOrNull(const llvm::json::Value ¶ms, - llvm::StringLiteral prop, T &out, - llvm::json::Path path) { - const llvm::json::Object *o = params.getAsObject(); - assert(o); - - // Field is missing or null. - auto *v = o->get(prop); - if (!v || v->getAsNull()) - return true; - return fromJSON(*v, out, path.field(prop)); -} - -//===----------------------------------------------------------------------===// -// LSPError -//===----------------------------------------------------------------------===// - -char LSPError::ID; - -//===----------------------------------------------------------------------===// -// URIForFile -//===----------------------------------------------------------------------===// - -static bool isWindowsPath(StringRef path) { - return path.size() > 1 && llvm::isAlpha(path[0]) && path[1] == ':'; -} - -static bool isNetworkPath(StringRef path) { - return path.size() > 2 && path[0] == path[1] && - llvm::sys::path::is_separator(path[0]); -} - -static bool shouldEscapeInURI(unsigned char c) { - // Unreserved characters. - if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || - (c >= '0' && c <= '9')) - return false; - - switch (c) { - case '-': - case '_': - case '.': - case '~': - // '/' is only reserved when parsing. - case '/': - // ':' is only reserved for relative URI paths, which we doesn't produce. - case ':': - return false; - } - return true; -} - -/// Encodes a string according to percent-encoding. -/// - Unreserved characters are not escaped. -/// - Reserved characters always escaped with exceptions like '/'. -/// - All other characters are escaped. -static void percentEncode(StringRef content, std::string &out) { - for (unsigned char c : content) { - if (shouldEscapeInURI(c)) { - out.push_back('%'); - out.push_back(llvm::hexdigit(c / 16)); - out.push_back(llvm::hexdigit(c % 16)); - } else { - out.push_back(c); - } - } -} - -/// Decodes a string according to percent-encoding. -static std::string percentDecode(StringRef content) { - std::string result; - for (auto i = content.begin(), e = content.end(); i != e; ++i) { - if (*i != '%') { - result += *i; - continue; - } - if (*i == '%' && i + 2 < content.end() && llvm::isHexDigit(*(i + 1)) && - llvm::isHexDigit(*(i + 2))) { - result.push_back(llvm::hexFromNibbles(*(i + 1), *(i + 2))); - i += 2; - } else { - result.push_back(*i); - } - } - return result; -} - -/// Return the set containing the supported URI schemes. -static StringSet<> &getSupportedSchemes() { - static StringSet<> schemes({"file", "test"}); - return schemes; -} - -/// Returns true if the given scheme is structurally valid, i.e. it does not -/// contain any invalid scheme characters. This does not check that the scheme -/// is actually supported. -static bool isStructurallyValidScheme(StringRef scheme) { - if (scheme.empty()) - return false; - if (!llvm::isAlpha(scheme[0])) - return false; - return llvm::all_of(llvm::drop_begin(scheme), [](char c) { - return llvm::isAlnum(c) || c == '+' || c == '.' || c == '-'; - }); -} - -static llvm::Expected uriFromAbsolutePath(StringRef absolutePath, - StringRef scheme) { - std::string body; - StringRef authority; - StringRef root = llvm::sys::path::root_name(absolutePath); - if (isNetworkPath(root)) { - // Windows UNC paths e.g. \\server\share => file://server/share - authority = root.drop_front(2); - absolutePath.consume_front(root); - } else if (isWindowsPath(root)) { - // Windows paths e.g. X:\path => file:///X:/path - body = "/"; - } - body += llvm::sys::path::convert_to_slash(absolutePath); - - std::string uri = scheme.str() + ":"; - if (authority.empty() && body.empty()) - return uri; - - // If authority if empty, we only print body if it starts with "/"; otherwise, - // the URI is invalid. - if (!authority.empty() || StringRef(body).starts_with("/")) { - uri.append("//"); - percentEncode(authority, uri); - } - percentEncode(body, uri); - return uri; -} - -static llvm::Expected getAbsolutePath(StringRef authority, - StringRef body) { - if (!body.starts_with("/")) - return llvm::createStringError( - llvm::inconvertibleErrorCode(), - "File scheme: expect body to be an absolute path starting " - "with '/': " + - body); - SmallString<128> path; - if (!authority.empty()) { - // Windows UNC paths e.g. file://server/share => \\server\share - ("//" + authority).toVector(path); - } else if (isWindowsPath(body.substr(1))) { - // Windows paths e.g. file:///X:/path => X:\path - body.consume_front("/"); - } - path.append(body); - llvm::sys::path::native(path); - return std::string(path); -} - -static llvm::Expected parseFilePathFromURI(StringRef origUri) { - StringRef uri = origUri; - - // Decode the scheme of the URI. - size_t pos = uri.find(':'); - if (pos == StringRef::npos) - return llvm::createStringError(llvm::inconvertibleErrorCode(), - "Scheme must be provided in URI: " + - origUri); - StringRef schemeStr = uri.substr(0, pos); - std::string uriScheme = percentDecode(schemeStr); - if (!isStructurallyValidScheme(uriScheme)) - return llvm::createStringError(llvm::inconvertibleErrorCode(), - "Invalid scheme: " + schemeStr + - " (decoded: " + uriScheme + ")"); - uri = uri.substr(pos + 1); - - // Decode the authority of the URI. - std::string uriAuthority; - if (uri.consume_front("//")) { - pos = uri.find('/'); - uriAuthority = percentDecode(uri.substr(0, pos)); - uri = uri.substr(pos); - } - - // Decode the body of the URI. - std::string uriBody = percentDecode(uri); - - // Compute the absolute path for this uri. - if (!getSupportedSchemes().contains(uriScheme)) { - return llvm::createStringError(llvm::inconvertibleErrorCode(), - "unsupported URI scheme `" + uriScheme + - "' for workspace files"); - } - return getAbsolutePath(uriAuthority, uriBody); -} - -llvm::Expected URIForFile::fromURI(StringRef uri) { - llvm::Expected filePath = parseFilePathFromURI(uri); - if (!filePath) - return filePath.takeError(); - return URIForFile(std::move(*filePath), uri.str()); -} - -llvm::Expected URIForFile::fromFile(StringRef absoluteFilepath, - StringRef scheme) { - llvm::Expected uri = - uriFromAbsolutePath(absoluteFilepath, scheme); - if (!uri) - return uri.takeError(); - return fromURI(*uri); -} - -StringRef URIForFile::scheme() const { return uri().split(':').first; } - -void URIForFile::registerSupportedScheme(StringRef scheme) { - getSupportedSchemes().insert(scheme); -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, URIForFile &result, - llvm::json::Path path) { - if (std::optional str = value.getAsString()) { - llvm::Expected expectedURI = URIForFile::fromURI(*str); - if (!expectedURI) { - path.report("unresolvable URI"); - consumeError(expectedURI.takeError()); - return false; - } - result = std::move(*expectedURI); - return true; - } - return false; -} - -llvm::json::Value mlir::lsp::toJSON(const URIForFile &value) { - return value.uri(); -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const URIForFile &value) { - return os << value.uri(); -} - -//===----------------------------------------------------------------------===// -// ClientCapabilities -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - ClientCapabilities &result, llvm::json::Path path) { - const llvm::json::Object *o = value.getAsObject(); - if (!o) { - path.report("expected object"); - return false; - } - if (const llvm::json::Object *textDocument = o->getObject("textDocument")) { - if (const llvm::json::Object *documentSymbol = - textDocument->getObject("documentSymbol")) { - if (std::optional hierarchicalSupport = - documentSymbol->getBoolean("hierarchicalDocumentSymbolSupport")) - result.hierarchicalDocumentSymbol = *hierarchicalSupport; - } - if (auto *codeAction = textDocument->getObject("codeAction")) { - if (codeAction->getObject("codeActionLiteralSupport")) - result.codeActionStructure = true; - } - } - if (auto *window = o->getObject("window")) { - if (std::optional workDoneProgressSupport = - window->getBoolean("workDoneProgress")) - result.workDoneProgress = *workDoneProgressSupport; - } - return true; -} - -//===----------------------------------------------------------------------===// -// ClientInfo -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, ClientInfo &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - if (!o || !o.map("name", result.name)) - return false; - - // Don't fail if we can't parse version. - o.map("version", result.version); - return true; -} - -//===----------------------------------------------------------------------===// -// InitializeParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, TraceLevel &result, - llvm::json::Path path) { - if (std::optional str = value.getAsString()) { - if (*str == "off") { - result = TraceLevel::Off; - return true; - } - if (*str == "messages") { - result = TraceLevel::Messages; - return true; - } - if (*str == "verbose") { - result = TraceLevel::Verbose; - return true; - } - } - return false; -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - InitializeParams &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - if (!o) - return false; - // We deliberately don't fail if we can't parse individual fields. - o.map("capabilities", result.capabilities); - o.map("trace", result.trace); - mapOptOrNull(value, "clientInfo", result.clientInfo, path); - - return true; -} - -//===----------------------------------------------------------------------===// -// TextDocumentItem -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - TextDocumentItem &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("uri", result.uri) && - o.map("languageId", result.languageId) && o.map("text", result.text) && - o.map("version", result.version); -} - -//===----------------------------------------------------------------------===// -// TextDocumentIdentifier -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const TextDocumentIdentifier &value) { - return llvm::json::Object{{"uri", value.uri}}; -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - TextDocumentIdentifier &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("uri", result.uri); -} - -//===----------------------------------------------------------------------===// -// VersionedTextDocumentIdentifier -//===----------------------------------------------------------------------===// - -llvm::json::Value -mlir::lsp::toJSON(const VersionedTextDocumentIdentifier &value) { - return llvm::json::Object{ - {"uri", value.uri}, - {"version", value.version}, - }; -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - VersionedTextDocumentIdentifier &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("uri", result.uri) && o.map("version", result.version); -} - -//===----------------------------------------------------------------------===// -// Position -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, Position &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("line", result.line) && - o.map("character", result.character); -} - -llvm::json::Value mlir::lsp::toJSON(const Position &value) { - return llvm::json::Object{ - {"line", value.line}, - {"character", value.character}, - }; -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const Position &value) { - return os << value.line << ':' << value.character; -} - -//===----------------------------------------------------------------------===// -// Range -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, Range &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("start", result.start) && o.map("end", result.end); -} - -llvm::json::Value mlir::lsp::toJSON(const Range &value) { - return llvm::json::Object{ - {"start", value.start}, - {"end", value.end}, - }; -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const Range &value) { - return os << value.start << '-' << value.end; -} - -//===----------------------------------------------------------------------===// -// Location -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, Location &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("uri", result.uri) && o.map("range", result.range); -} - -llvm::json::Value mlir::lsp::toJSON(const Location &value) { - return llvm::json::Object{ - {"uri", value.uri}, - {"range", value.range}, - }; -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const Location &value) { - return os << value.range << '@' << value.uri; -} - -//===----------------------------------------------------------------------===// -// TextDocumentPositionParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - TextDocumentPositionParams &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument) && - o.map("position", result.position); -} - -//===----------------------------------------------------------------------===// -// ReferenceParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - ReferenceContext &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.mapOptional("includeDeclaration", result.includeDeclaration); -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - ReferenceParams &result, llvm::json::Path path) { - TextDocumentPositionParams &base = result; - llvm::json::ObjectMapper o(value, path); - return fromJSON(value, base, path) && o && - o.mapOptional("context", result.context); -} - -//===----------------------------------------------------------------------===// -// DidOpenTextDocumentParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - DidOpenTextDocumentParams &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument); -} - -//===----------------------------------------------------------------------===// -// DidCloseTextDocumentParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - DidCloseTextDocumentParams &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument); -} - -//===----------------------------------------------------------------------===// -// DidChangeTextDocumentParams -//===----------------------------------------------------------------------===// - -LogicalResult -TextDocumentContentChangeEvent::applyTo(std::string &contents) const { - // If there is no range, the full document changed. - if (!range) { - contents = text; - return success(); - } - - // Try to map the replacement range to the content. - llvm::SourceMgr tmpScrMgr; - tmpScrMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(contents), - SMLoc()); - SMRange rangeLoc = range->getAsSMRange(tmpScrMgr); - if (!rangeLoc.isValid()) - return failure(); - - contents.replace(rangeLoc.Start.getPointer() - contents.data(), - rangeLoc.End.getPointer() - rangeLoc.Start.getPointer(), - text); - return success(); -} - -LogicalResult TextDocumentContentChangeEvent::applyTo( - ArrayRef changes, std::string &contents) { - for (const auto &change : changes) - if (failed(change.applyTo(contents))) - return failure(); - return success(); -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - TextDocumentContentChangeEvent &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("range", result.range) && - o.map("rangeLength", result.rangeLength) && o.map("text", result.text); -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - DidChangeTextDocumentParams &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument) && - o.map("contentChanges", result.contentChanges); -} - -//===----------------------------------------------------------------------===// -// MarkupContent -//===----------------------------------------------------------------------===// - -static llvm::StringRef toTextKind(MarkupKind kind) { - switch (kind) { - case MarkupKind::PlainText: - return "plaintext"; - case MarkupKind::Markdown: - return "markdown"; - } - llvm_unreachable("Invalid MarkupKind"); -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, MarkupKind kind) { - return os << toTextKind(kind); -} - -llvm::json::Value mlir::lsp::toJSON(const MarkupContent &mc) { - if (mc.value.empty()) - return nullptr; - - return llvm::json::Object{ - {"kind", toTextKind(mc.kind)}, - {"value", mc.value}, - }; -} - -//===----------------------------------------------------------------------===// -// Hover -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const Hover &hover) { - llvm::json::Object result{{"contents", toJSON(hover.contents)}}; - if (hover.range) - result["range"] = toJSON(*hover.range); - return std::move(result); -} - -//===----------------------------------------------------------------------===// -// DocumentSymbol -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const DocumentSymbol &symbol) { - llvm::json::Object result{{"name", symbol.name}, - {"kind", static_cast(symbol.kind)}, - {"range", symbol.range}, - {"selectionRange", symbol.selectionRange}}; - - if (!symbol.detail.empty()) - result["detail"] = symbol.detail; - if (!symbol.children.empty()) - result["children"] = symbol.children; - return std::move(result); -} - -//===----------------------------------------------------------------------===// -// DocumentSymbolParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - DocumentSymbolParams &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument); -} - -//===----------------------------------------------------------------------===// -// DiagnosticRelatedInformation -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - DiagnosticRelatedInformation &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("location", result.location) && - o.map("message", result.message); -} - -llvm::json::Value mlir::lsp::toJSON(const DiagnosticRelatedInformation &info) { - return llvm::json::Object{ - {"location", info.location}, - {"message", info.message}, - }; -} - -//===----------------------------------------------------------------------===// -// Diagnostic -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(DiagnosticTag tag) { - return static_cast(tag); -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, DiagnosticTag &result, - llvm::json::Path path) { - if (std::optional i = value.getAsInteger()) { - result = (DiagnosticTag)*i; - return true; - } - - return false; -} - -llvm::json::Value mlir::lsp::toJSON(const Diagnostic &diag) { - llvm::json::Object result{ - {"range", diag.range}, - {"severity", (int)diag.severity}, - {"message", diag.message}, - }; - if (diag.category) - result["category"] = *diag.category; - if (!diag.source.empty()) - result["source"] = diag.source; - if (diag.relatedInformation) - result["relatedInformation"] = *diag.relatedInformation; - if (!diag.tags.empty()) - result["tags"] = diag.tags; - return std::move(result); -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, Diagnostic &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - if (!o) - return false; - int severity = 0; - if (!mapOptOrNull(value, "severity", severity, path)) - return false; - result.severity = (DiagnosticSeverity)severity; - - return o.map("range", result.range) && o.map("message", result.message) && - mapOptOrNull(value, "category", result.category, path) && - mapOptOrNull(value, "source", result.source, path) && - mapOptOrNull(value, "relatedInformation", result.relatedInformation, - path) && - mapOptOrNull(value, "tags", result.tags, path); -} - -//===----------------------------------------------------------------------===// -// PublishDiagnosticsParams -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const PublishDiagnosticsParams ¶ms) { - return llvm::json::Object{ - {"uri", params.uri}, - {"diagnostics", params.diagnostics}, - {"version", params.version}, - }; -} - -//===----------------------------------------------------------------------===// -// TextEdit -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, TextEdit &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("range", result.range) && o.map("newText", result.newText); -} - -llvm::json::Value mlir::lsp::toJSON(const TextEdit &value) { - return llvm::json::Object{ - {"range", value.range}, - {"newText", value.newText}, - }; -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const TextEdit &value) { - os << value.range << " => \""; - llvm::printEscapedString(value.newText, os); - return os << '"'; -} - -//===----------------------------------------------------------------------===// -// CompletionItemKind -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - CompletionItemKind &result, llvm::json::Path path) { - if (std::optional intValue = value.getAsInteger()) { - if (*intValue < static_cast(CompletionItemKind::Text) || - *intValue > static_cast(CompletionItemKind::TypeParameter)) - return false; - result = static_cast(*intValue); - return true; - } - return false; -} - -CompletionItemKind mlir::lsp::adjustKindToCapability( - CompletionItemKind kind, - CompletionItemKindBitset &supportedCompletionItemKinds) { - size_t kindVal = static_cast(kind); - if (kindVal >= kCompletionItemKindMin && - kindVal <= supportedCompletionItemKinds.size() && - supportedCompletionItemKinds[kindVal]) - return kind; - - // Provide some fall backs for common kinds that are close enough. - switch (kind) { - case CompletionItemKind::Folder: - return CompletionItemKind::File; - case CompletionItemKind::EnumMember: - return CompletionItemKind::Enum; - case CompletionItemKind::Struct: - return CompletionItemKind::Class; - default: - return CompletionItemKind::Text; - } -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - CompletionItemKindBitset &result, - llvm::json::Path path) { - if (const llvm::json::Array *arrayValue = value.getAsArray()) { - for (size_t i = 0, e = arrayValue->size(); i < e; ++i) { - CompletionItemKind kindOut; - if (fromJSON((*arrayValue)[i], kindOut, path.index(i))) - result.set(size_t(kindOut)); - } - return true; - } - return false; -} - -//===----------------------------------------------------------------------===// -// CompletionItem -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const CompletionItem &value) { - assert(!value.label.empty() && "completion item label is required"); - llvm::json::Object result{{"label", value.label}}; - if (value.kind != CompletionItemKind::Missing) - result["kind"] = static_cast(value.kind); - if (!value.detail.empty()) - result["detail"] = value.detail; - if (value.documentation) - result["documentation"] = value.documentation; - if (!value.sortText.empty()) - result["sortText"] = value.sortText; - if (!value.filterText.empty()) - result["filterText"] = value.filterText; - if (!value.insertText.empty()) - result["insertText"] = value.insertText; - if (value.insertTextFormat != InsertTextFormat::Missing) - result["insertTextFormat"] = static_cast(value.insertTextFormat); - if (value.textEdit) - result["textEdit"] = *value.textEdit; - if (!value.additionalTextEdits.empty()) { - result["additionalTextEdits"] = - llvm::json::Array(value.additionalTextEdits); - } - if (value.deprecated) - result["deprecated"] = value.deprecated; - return std::move(result); -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, - const CompletionItem &value) { - return os << value.label << " - " << toJSON(value); -} - -bool mlir::lsp::operator<(const CompletionItem &lhs, - const CompletionItem &rhs) { - return (lhs.sortText.empty() ? lhs.label : lhs.sortText) < - (rhs.sortText.empty() ? rhs.label : rhs.sortText); -} - -//===----------------------------------------------------------------------===// -// CompletionList -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const CompletionList &value) { - return llvm::json::Object{ - {"isIncomplete", value.isIncomplete}, - {"items", llvm::json::Array(value.items)}, - }; -} - -//===----------------------------------------------------------------------===// -// CompletionContext -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - CompletionContext &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - int triggerKind; - if (!o || !o.map("triggerKind", triggerKind) || - !mapOptOrNull(value, "triggerCharacter", result.triggerCharacter, path)) - return false; - result.triggerKind = static_cast(triggerKind); - return true; -} - -//===----------------------------------------------------------------------===// -// CompletionParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - CompletionParams &result, llvm::json::Path path) { - if (!fromJSON(value, static_cast(result), path)) - return false; - if (const llvm::json::Value *context = value.getAsObject()->get("context")) - return fromJSON(*context, result.context, path.field("context")); - return true; -} - -//===----------------------------------------------------------------------===// -// ParameterInformation -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const ParameterInformation &value) { - assert((value.labelOffsets || !value.labelString.empty()) && - "parameter information label is required"); - llvm::json::Object result; - if (value.labelOffsets) - result["label"] = llvm::json::Array( - {value.labelOffsets->first, value.labelOffsets->second}); - else - result["label"] = value.labelString; - if (!value.documentation.empty()) - result["documentation"] = value.documentation; - return std::move(result); -} - -//===----------------------------------------------------------------------===// -// SignatureInformation -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const SignatureInformation &value) { - assert(!value.label.empty() && "signature information label is required"); - llvm::json::Object result{ - {"label", value.label}, - {"parameters", llvm::json::Array(value.parameters)}, - }; - if (!value.documentation.empty()) - result["documentation"] = value.documentation; - return std::move(result); -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, - const SignatureInformation &value) { - return os << value.label << " - " << toJSON(value); -} - -//===----------------------------------------------------------------------===// -// SignatureHelp -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const SignatureHelp &value) { - assert(value.activeSignature >= 0 && - "Unexpected negative value for number of active signatures."); - assert(value.activeParameter >= 0 && - "Unexpected negative value for active parameter index"); - return llvm::json::Object{ - {"activeSignature", value.activeSignature}, - {"activeParameter", value.activeParameter}, - {"signatures", llvm::json::Array(value.signatures)}, - }; -} - -//===----------------------------------------------------------------------===// -// DocumentLinkParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - DocumentLinkParams &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument); -} - -//===----------------------------------------------------------------------===// -// DocumentLink -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const DocumentLink &value) { - return llvm::json::Object{ - {"range", value.range}, - {"target", value.target}, - }; -} - -//===----------------------------------------------------------------------===// -// InlayHintsParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - InlayHintsParams &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument) && - o.map("range", result.range); -} - -//===----------------------------------------------------------------------===// -// InlayHint -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const InlayHint &value) { - return llvm::json::Object{{"position", value.position}, - {"kind", (int)value.kind}, - {"label", value.label}, - {"paddingLeft", value.paddingLeft}, - {"paddingRight", value.paddingRight}}; -} -bool mlir::lsp::operator==(const InlayHint &lhs, const InlayHint &rhs) { - return std::tie(lhs.position, lhs.kind, lhs.label) == - std::tie(rhs.position, rhs.kind, rhs.label); -} -bool mlir::lsp::operator<(const InlayHint &lhs, const InlayHint &rhs) { - return std::tie(lhs.position, lhs.kind, lhs.label) < - std::tie(rhs.position, rhs.kind, rhs.label); -} - -llvm::raw_ostream &mlir::lsp::operator<<(llvm::raw_ostream &os, - InlayHintKind value) { - switch (value) { - case InlayHintKind::Parameter: - return os << "parameter"; - case InlayHintKind::Type: - return os << "type"; - } - llvm_unreachable("Unknown InlayHintKind"); -} - -//===----------------------------------------------------------------------===// -// CodeActionContext -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - CodeActionContext &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - if (!o || !o.map("diagnostics", result.diagnostics)) - return false; - o.map("only", result.only); - return true; -} - -//===----------------------------------------------------------------------===// -// CodeActionParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - CodeActionParams &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument) && - o.map("range", result.range) && o.map("context", result.context); -} - -//===----------------------------------------------------------------------===// -// WorkspaceEdit -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, WorkspaceEdit &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("changes", result.changes); -} - -llvm::json::Value mlir::lsp::toJSON(const WorkspaceEdit &value) { - llvm::json::Object fileChanges; - for (auto &change : value.changes) - fileChanges[change.first] = llvm::json::Array(change.second); - return llvm::json::Object{{"changes", std::move(fileChanges)}}; -} - -//===----------------------------------------------------------------------===// -// CodeAction -//===----------------------------------------------------------------------===// - -const llvm::StringLiteral CodeAction::kQuickFix = "quickfix"; -const llvm::StringLiteral CodeAction::kRefactor = "refactor"; -const llvm::StringLiteral CodeAction::kInfo = "info"; - -llvm::json::Value mlir::lsp::toJSON(const CodeAction &value) { - llvm::json::Object codeAction{{"title", value.title}}; - if (value.kind) - codeAction["kind"] = *value.kind; - if (value.diagnostics) - codeAction["diagnostics"] = llvm::json::Array(*value.diagnostics); - if (value.isPreferred) - codeAction["isPreferred"] = true; - if (value.edit) - codeAction["edit"] = *value.edit; - return std::move(codeAction); -} diff --git a/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.cpp b/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.cpp index f1a362385f285..5cd1c85d054ab 100644 --- a/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.cpp +++ b/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.cpp @@ -14,6 +14,10 @@ using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::Hover; +using llvm::lsp::Range; +using llvm::lsp::URIForFile; + //===----------------------------------------------------------------------===// // Utils //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/lsp-server-support/Transport.cpp b/mlir/lib/Tools/lsp-server-support/Transport.cpp deleted file mode 100644 index 5a098b2841f4b..0000000000000 --- a/mlir/lib/Tools/lsp-server-support/Transport.cpp +++ /dev/null @@ -1,369 +0,0 @@ -//===--- JSONTransport.cpp - sending and receiving LSP messages over JSON -===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Tools/lsp-server-support/Transport.h" -#include "mlir/Support/ToolUtilities.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Protocol.h" -#include "llvm/ADT/SmallString.h" -#include "llvm/Support/Error.h" -#include -#include -#include - -using namespace mlir; -using namespace mlir::lsp; - -//===----------------------------------------------------------------------===// -// Reply -//===----------------------------------------------------------------------===// - -namespace { -/// Function object to reply to an LSP call. -/// Each instance must be called exactly once, otherwise: -/// - if there was no reply, an error reply is sent -/// - if there were multiple replies, only the first is sent -class Reply { -public: - Reply(const llvm::json::Value &id, StringRef method, JSONTransport &transport, - std::mutex &transportOutputMutex); - Reply(Reply &&other); - Reply &operator=(Reply &&) = delete; - Reply(const Reply &) = delete; - Reply &operator=(const Reply &) = delete; - - void operator()(llvm::Expected reply); - -private: - std::string method; - std::atomic replied = {false}; - llvm::json::Value id; - JSONTransport *transport; - std::mutex &transportOutputMutex; -}; -} // namespace - -Reply::Reply(const llvm::json::Value &id, llvm::StringRef method, - JSONTransport &transport, std::mutex &transportOutputMutex) - : method(method), id(id), transport(&transport), - transportOutputMutex(transportOutputMutex) {} - -Reply::Reply(Reply &&other) - : method(other.method), replied(other.replied.load()), - id(std::move(other.id)), transport(other.transport), - transportOutputMutex(other.transportOutputMutex) { - other.transport = nullptr; -} - -void Reply::operator()(llvm::Expected reply) { - if (replied.exchange(true)) { - Logger::error("Replied twice to message {0}({1})", method, id); - assert(false && "must reply to each call only once!"); - return; - } - assert(transport && "expected valid transport to reply to"); - - std::lock_guard transportLock(transportOutputMutex); - if (reply) { - Logger::info("--> reply:{0}({1})", method, id); - transport->reply(std::move(id), std::move(reply)); - } else { - llvm::Error error = reply.takeError(); - Logger::info("--> reply:{0}({1}): {2}", method, id, error); - transport->reply(std::move(id), std::move(error)); - } -} - -//===----------------------------------------------------------------------===// -// MessageHandler -//===----------------------------------------------------------------------===// - -bool MessageHandler::onNotify(llvm::StringRef method, llvm::json::Value value) { - Logger::info("--> {0}", method); - - if (method == "exit") - return false; - if (method == "$cancel") { - // TODO: Add support for cancelling requests. - } else { - auto it = notificationHandlers.find(method); - if (it != notificationHandlers.end()) - it->second(std::move(value)); - } - return true; -} - -bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params, - llvm::json::Value id) { - Logger::info("--> {0}({1})", method, id); - - Reply reply(id, method, transport, transportOutputMutex); - - auto it = methodHandlers.find(method); - if (it != methodHandlers.end()) { - it->second(std::move(params), std::move(reply)); - } else { - reply(llvm::make_error("method not found: " + method.str(), - ErrorCode::MethodNotFound)); - } - return true; -} - -bool MessageHandler::onReply(llvm::json::Value id, - llvm::Expected result) { - // Find the response handler in the mapping. If it exists, move it out of the - // mapping and erase it. - ResponseHandlerTy responseHandler; - { - std::lock_guard responseHandlersLock(responseHandlersMutex); - auto it = responseHandlers.find(debugString(id)); - if (it != responseHandlers.end()) { - responseHandler = std::move(it->second); - responseHandlers.erase(it); - } - } - - // If we found a response handler, invoke it. Otherwise, log an error. - if (responseHandler.second) { - Logger::info("--> reply:{0}({1})", responseHandler.first, id); - responseHandler.second(std::move(id), std::move(result)); - } else { - Logger::error( - "received a reply with ID {0}, but there was no such outgoing request", - id); - if (!result) - llvm::consumeError(result.takeError()); - } - return true; -} - -//===----------------------------------------------------------------------===// -// JSONTransport -//===----------------------------------------------------------------------===// - -/// Encode the given error as a JSON object. -static llvm::json::Object encodeError(llvm::Error error) { - std::string message; - ErrorCode code = ErrorCode::UnknownErrorCode; - auto handlerFn = [&](const LSPError &lspError) -> llvm::Error { - message = lspError.message; - code = lspError.code; - return llvm::Error::success(); - }; - if (llvm::Error unhandled = llvm::handleErrors(std::move(error), handlerFn)) - message = llvm::toString(std::move(unhandled)); - - return llvm::json::Object{ - {"message", std::move(message)}, - {"code", int64_t(code)}, - }; -} - -/// Decode the given JSON object into an error. -llvm::Error decodeError(const llvm::json::Object &o) { - StringRef msg = o.getString("message").value_or("Unspecified error"); - if (std::optional code = o.getInteger("code")) - return llvm::make_error(msg.str(), ErrorCode(*code)); - return llvm::make_error(llvm::inconvertibleErrorCode(), - msg.str()); -} - -void JSONTransport::notify(StringRef method, llvm::json::Value params) { - sendMessage(llvm::json::Object{ - {"jsonrpc", "2.0"}, - {"method", method}, - {"params", std::move(params)}, - }); -} -void JSONTransport::call(StringRef method, llvm::json::Value params, - llvm::json::Value id) { - sendMessage(llvm::json::Object{ - {"jsonrpc", "2.0"}, - {"id", std::move(id)}, - {"method", method}, - {"params", std::move(params)}, - }); -} -void JSONTransport::reply(llvm::json::Value id, - llvm::Expected result) { - if (result) { - return sendMessage(llvm::json::Object{ - {"jsonrpc", "2.0"}, - {"id", std::move(id)}, - {"result", std::move(*result)}, - }); - } - - sendMessage(llvm::json::Object{ - {"jsonrpc", "2.0"}, - {"id", std::move(id)}, - {"error", encodeError(result.takeError())}, - }); -} - -llvm::Error JSONTransport::run(MessageHandler &handler) { - std::string json; - while (!in->isEndOfInput()) { - if (in->hasError()) { - return llvm::errorCodeToError( - std::error_code(errno, std::system_category())); - } - - if (succeeded(in->readMessage(json))) { - if (llvm::Expected doc = llvm::json::parse(json)) { - if (!handleMessage(std::move(*doc), handler)) - return llvm::Error::success(); - } else { - Logger::error("JSON parse error: {0}", llvm::toString(doc.takeError())); - } - } - } - return llvm::errorCodeToError(std::make_error_code(std::errc::io_error)); -} - -void JSONTransport::sendMessage(llvm::json::Value msg) { - outputBuffer.clear(); - llvm::raw_svector_ostream os(outputBuffer); - os << llvm::formatv(prettyOutput ? "{0:2}\n" : "{0}", msg); - out << "Content-Length: " << outputBuffer.size() << "\r\n\r\n" - << outputBuffer; - out.flush(); - Logger::debug(">>> {0}\n", outputBuffer); -} - -bool JSONTransport::handleMessage(llvm::json::Value msg, - MessageHandler &handler) { - // Message must be an object with "jsonrpc":"2.0". - llvm::json::Object *object = msg.getAsObject(); - if (!object || - object->getString("jsonrpc") != std::optional("2.0")) - return false; - - // `id` may be any JSON value. If absent, this is a notification. - std::optional id; - if (llvm::json::Value *i = object->get("id")) - id = std::move(*i); - std::optional method = object->getString("method"); - - // This is a response. - if (!method) { - if (!id) - return false; - if (auto *err = object->getObject("error")) - return handler.onReply(std::move(*id), decodeError(*err)); - // result should be given, use null if not. - llvm::json::Value result = nullptr; - if (llvm::json::Value *r = object->get("result")) - result = std::move(*r); - return handler.onReply(std::move(*id), std::move(result)); - } - - // Params should be given, use null if not. - llvm::json::Value params = nullptr; - if (llvm::json::Value *p = object->get("params")) - params = std::move(*p); - - if (id) - return handler.onCall(*method, std::move(params), std::move(*id)); - return handler.onNotify(*method, std::move(params)); -} - -/// Tries to read a line up to and including \n. -/// If failing, feof(), ferror(), or shutdownRequested() will be set. -LogicalResult readLine(std::FILE *in, SmallVectorImpl &out) { - // Big enough to hold any reasonable header line. May not fit content lines - // in delimited mode, but performance doesn't matter for that mode. - static constexpr int bufSize = 128; - size_t size = 0; - out.clear(); - for (;;) { - out.resize_for_overwrite(size + bufSize); - if (!std::fgets(&out[size], bufSize, in)) - return failure(); - - clearerr(in); - - // If the line contained null bytes, anything after it (including \n) will - // be ignored. Fortunately this is not a legal header or JSON. - size_t read = std::strlen(&out[size]); - if (read > 0 && out[size + read - 1] == '\n') { - out.resize(size + read); - return success(); - } - size += read; - } -} - -// Returns std::nullopt when: -// - ferror(), feof(), or shutdownRequested() are set. -// - Content-Length is missing or empty (protocol error) -LogicalResult -JSONTransportInputOverFile::readStandardMessage(std::string &json) { - // A Language Server Protocol message starts with a set of HTTP headers, - // delimited by \r\n, and terminated by an empty line (\r\n). - unsigned long long contentLength = 0; - llvm::SmallString<128> line; - while (true) { - if (feof(in) || hasError() || failed(readLine(in, line))) - return failure(); - - // Content-Length is a mandatory header, and the only one we handle. - StringRef lineRef = line; - if (lineRef.consume_front("Content-Length: ")) { - llvm::getAsUnsignedInteger(lineRef.trim(), 0, contentLength); - } else if (!lineRef.trim().empty()) { - // It's another header, ignore it. - continue; - } else { - // An empty line indicates the end of headers. Go ahead and read the JSON. - break; - } - } - - // The fuzzer likes crashing us by sending "Content-Length: 9999999999999999" - if (contentLength == 0 || contentLength > 1 << 30) - return failure(); - - json.resize(contentLength); - for (size_t pos = 0, read; pos < contentLength; pos += read) { - read = std::fread(&json[pos], 1, contentLength - pos, in); - if (read == 0) - return failure(); - - // If we're done, the error was transient. If we're not done, either it was - // transient or we'll see it again on retry. - clearerr(in); - pos += read; - } - return success(); -} - -/// For lit tests we support a simplified syntax: -/// - messages are delimited by '// -----' on a line by itself -/// - lines starting with // are ignored. -/// This is a testing path, so favor simplicity over performance here. -/// When returning failure: feof(), ferror(), or shutdownRequested() will be -/// set. -LogicalResult -JSONTransportInputOverFile::readDelimitedMessage(std::string &json) { - json.clear(); - llvm::SmallString<128> line; - while (succeeded(readLine(in, line))) { - StringRef lineRef = line.str().trim(); - if (lineRef.starts_with("//")) { - // Found a delimiter for the message. - if (lineRef == kDefaultSplitMarker) - break; - continue; - } - - json += line; - } - - return failure(ferror(in)); -} diff --git a/mlir/lib/Tools/mlir-lsp-server/CMakeLists.txt b/mlir/lib/Tools/mlir-lsp-server/CMakeLists.txt index d04d5156fb3c5..e2acba54e5624 100644 --- a/mlir/lib/Tools/mlir-lsp-server/CMakeLists.txt +++ b/mlir/lib/Tools/mlir-lsp-server/CMakeLists.txt @@ -7,6 +7,9 @@ add_mlir_library(MLIRLspServerLib ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-lsp-server + LINK_COMPONENTS + SupportLSP + LINK_LIBS PUBLIC MLIRBytecodeWriter MLIRFunctionInterfaces diff --git a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp index 9b937db0c6a7a..1bbbcdecb57af 100644 --- a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp @@ -9,8 +9,8 @@ #include "LSPServer.h" #include "MLIRServer.h" #include "Protocol.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Transport.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Transport.h" #include #define DEBUG_TYPE "mlir-lsp-server" @@ -18,6 +18,33 @@ using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::Callback; +using llvm::lsp::CodeAction; +using llvm::lsp::CodeActionParams; +using llvm::lsp::CompletionList; +using llvm::lsp::CompletionParams; +using llvm::lsp::DidChangeTextDocumentParams; +using llvm::lsp::DidCloseTextDocumentParams; +using llvm::lsp::DidOpenTextDocumentParams; +using llvm::lsp::DocumentSymbol; +using llvm::lsp::DocumentSymbolParams; +using llvm::lsp::Hover; +using llvm::lsp::InitializedParams; +using llvm::lsp::InitializeParams; +using llvm::lsp::JSONTransport; +using llvm::lsp::Location; +using llvm::lsp::Logger; +using llvm::lsp::MessageHandler; +using llvm::lsp::MLIRConvertBytecodeParams; +using llvm::lsp::MLIRConvertBytecodeResult; +using llvm::lsp::NoParams; +using llvm::lsp::OutgoingNotification; +using llvm::lsp::PublishDiagnosticsParams; +using llvm::lsp::ReferenceParams; +using llvm::lsp::TextDocumentPositionParams; +using llvm::lsp::TextDocumentSyncKind; +using llvm::lsp::URIForFile; + //===----------------------------------------------------------------------===// // LSPServer //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/mlir-lsp-server/LSPServer.h b/mlir/lib/Tools/mlir-lsp-server/LSPServer.h index 2c50c6b4ac6f5..d652899633255 100644 --- a/mlir/lib/Tools/mlir-lsp-server/LSPServer.h +++ b/mlir/lib/Tools/mlir-lsp-server/LSPServer.h @@ -13,17 +13,19 @@ namespace llvm { struct LogicalResult; +namespace lsp { +class JSONTransport; +} // namespace lsp } // namespace llvm namespace mlir { namespace lsp { -class JSONTransport; class MLIRServer; /// Run the main loop of the LSP server using the given MLIR server and /// transport. llvm::LogicalResult runMlirLSPServer(MLIRServer &server, - JSONTransport &transport); + llvm::lsp::JSONTransport &transport); } // namespace lsp } // namespace mlir diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp index 61987525a5ca5..47b4328d0d9ec 100644 --- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp @@ -16,10 +16,10 @@ #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Parser/Parser.h" #include "mlir/Support/ToolUtilities.h" -#include "mlir/Tools/lsp-server-support/Logging.h" #include "mlir/Tools/lsp-server-support/SourceMgrUtils.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Base64.h" +#include "llvm/Support/LSP/Logging.h" #include "llvm/Support/SourceMgr.h" #include @@ -39,9 +39,9 @@ static std::optional getLocationFromLoc(StringRef uriScheme, llvm::Expected sourceURI = lsp::URIForFile::fromFile(loc.getFilename(), uriScheme); if (!sourceURI) { - lsp::Logger::error("Failed to create URI for file `{0}`: {1}", - loc.getFilename(), - llvm::toString(sourceURI.takeError())); + llvm::lsp::Logger::error("Failed to create URI for file `{0}`: {1}", + loc.getFilename(), + llvm::toString(sourceURI.takeError())); return std::nullopt; } @@ -217,22 +217,22 @@ static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, // Convert the severity for the diagnostic. switch (diag.getSeverity()) { - case DiagnosticSeverity::Note: + case mlir::DiagnosticSeverity::Note: llvm_unreachable("expected notes to be handled separately"); - case DiagnosticSeverity::Warning: - lspDiag.severity = lsp::DiagnosticSeverity::Warning; + case mlir::DiagnosticSeverity::Warning: + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning; break; - case DiagnosticSeverity::Error: - lspDiag.severity = lsp::DiagnosticSeverity::Error; + case mlir::DiagnosticSeverity::Error: + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error; break; - case DiagnosticSeverity::Remark: - lspDiag.severity = lsp::DiagnosticSeverity::Information; + case mlir::DiagnosticSeverity::Remark: + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information; break; } lspDiag.message = diag.str(); // Attach any notes to the main diagnostic as related information. - std::vector relatedDiags; + std::vector relatedDiags; for (Diagnostic ¬e : diag.getNotes()) { lsp::Location noteLoc; if (std::optional loc = @@ -317,7 +317,7 @@ struct MLIRDocument { void getCodeActionForDiagnostic(const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity, StringRef message, - std::vector &edits); + std::vector &edits); //===--------------------------------------------------------------------===// // Bytecode @@ -355,7 +355,8 @@ MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri, // Try to parsed the given IR string. auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file()); if (!memBuffer) { - lsp::Logger::error("Failed to create memory buffer for file", uri.file()); + llvm::lsp::Logger::error("Failed to create memory buffer for file", + uri.file()); return; } @@ -695,8 +696,8 @@ void MLIRDocument::findDocumentSymbols( if (SymbolOpInterface symbol = dyn_cast(op)) { symbols.emplace_back(symbol.getName(), isa(op) - ? lsp::SymbolKind::Function - : lsp::SymbolKind::Class, + ? llvm::lsp::SymbolKind::Function + : llvm::lsp::SymbolKind::Class, lsp::Range(sourceMgr, def->scopeLoc), lsp::Range(sourceMgr, def->loc)); childSymbols = &symbols.back().children; @@ -704,9 +705,9 @@ void MLIRDocument::findDocumentSymbols( } else if (op->hasTrait()) { // Otherwise, if this is a symbol table push an anonymous document symbol. symbols.emplace_back("<" + op->getName().getStringRef() + ">", - lsp::SymbolKind::Namespace, - lsp::Range(sourceMgr, def->scopeLoc), - lsp::Range(sourceMgr, def->loc)); + llvm::lsp::SymbolKind::Namespace, + llvm::lsp::Range(sourceMgr, def->scopeLoc), + llvm::lsp::Range(sourceMgr, def->loc)); childSymbols = &symbols.back().children; } } @@ -734,9 +735,9 @@ class LSPCodeCompleteContext : public AsmParserCodeCompleteContext { /// Signal code completion for a dialect name, with an optional prefix. void completeDialectName(StringRef prefix) final { for (StringRef dialect : ctx->getAvailableDialects()) { - lsp::CompletionItem item(prefix + dialect, - lsp::CompletionItemKind::Module, - /*sortText=*/"3"); + llvm::lsp::CompletionItem item(prefix + dialect, + llvm::lsp::CompletionItemKind::Module, + /*sortText=*/"3"); item.detail = "dialect"; completionList.items.emplace_back(item); } @@ -753,9 +754,9 @@ class LSPCodeCompleteContext : public AsmParserCodeCompleteContext { if (&op.getDialect() != dialect) continue; - lsp::CompletionItem item( + llvm::lsp::CompletionItem item( op.getStringRef().drop_front(dialectName.size() + 1), - lsp::CompletionItemKind::Field, + llvm::lsp::CompletionItemKind::Field, /*sortText=*/"1"); item.detail = "operation"; completionList.items.emplace_back(item); @@ -768,7 +769,8 @@ class LSPCodeCompleteContext : public AsmParserCodeCompleteContext { // Check if we need to insert the `%` or not. bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '%'; - lsp::CompletionItem item(name, lsp::CompletionItemKind::Variable); + llvm::lsp::CompletionItem item(name, + llvm::lsp::CompletionItemKind::Variable); if (stripPrefix) item.insertText = name.drop_front(1).str(); item.detail = std::move(typeData); @@ -781,7 +783,7 @@ class LSPCodeCompleteContext : public AsmParserCodeCompleteContext { // Check if we need to insert the `^` or not. bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '^'; - lsp::CompletionItem item(name, lsp::CompletionItemKind::Field); + llvm::lsp::CompletionItem item(name, llvm::lsp::CompletionItemKind::Field); if (stripPrefix) item.insertText = name.drop_front(1).str(); completionList.items.emplace_back(item); @@ -790,8 +792,9 @@ class LSPCodeCompleteContext : public AsmParserCodeCompleteContext { /// Signal a completion for the given expected token. void completeExpectedTokens(ArrayRef tokens, bool optional) final { for (StringRef token : tokens) { - lsp::CompletionItem item(token, lsp::CompletionItemKind::Keyword, - /*sortText=*/"0"); + llvm::lsp::CompletionItem item(token, + llvm::lsp::CompletionItemKind::Keyword, + /*sortText=*/"0"); item.detail = optional ? "optional" : ""; completionList.items.emplace_back(item); } @@ -802,7 +805,7 @@ class LSPCodeCompleteContext : public AsmParserCodeCompleteContext { appendSimpleCompletions({"affine_set", "affine_map", "dense", "dense_resource", "false", "loc", "sparse", "true", "unit"}, - lsp::CompletionItemKind::Field, + llvm::lsp::CompletionItemKind::Field, /*sortText=*/"1"); completeDialectName("#"); @@ -820,13 +823,14 @@ class LSPCodeCompleteContext : public AsmParserCodeCompleteContext { appendSimpleCompletions({"memref", "tensor", "complex", "tuple", "vector", "bf16", "f16", "f32", "f64", "f80", "f128", "index", "none"}, - lsp::CompletionItemKind::Field, + llvm::lsp::CompletionItemKind::Field, /*sortText=*/"1"); // Handle the builtin integer types. for (StringRef type : {"i", "si", "ui"}) { - lsp::CompletionItem item(type + "", lsp::CompletionItemKind::Field, - /*sortText=*/"1"); + llvm::lsp::CompletionItem item(type + "", + llvm::lsp::CompletionItemKind::Field, + /*sortText=*/"1"); item.insertText = type.str(); completionList.items.emplace_back(item); } @@ -846,9 +850,9 @@ class LSPCodeCompleteContext : public AsmParserCodeCompleteContext { void completeAliases(const llvm::StringMap &aliases, StringRef prefix = "") { for (const auto &alias : aliases) { - lsp::CompletionItem item(prefix + alias.getKey(), - lsp::CompletionItemKind::Field, - /*sortText=*/"2"); + llvm::lsp::CompletionItem item(prefix + alias.getKey(), + llvm::lsp::CompletionItemKind::Field, + /*sortText=*/"2"); llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue(); completionList.items.emplace_back(item); } @@ -856,7 +860,7 @@ class LSPCodeCompleteContext : public AsmParserCodeCompleteContext { /// Add a set of simple completions that all have the same kind. void appendSimpleCompletions(ArrayRef completions, - lsp::CompletionItemKind kind, + llvm::lsp::CompletionItemKind kind, StringRef sortText = "") { for (StringRef completion : completions) completionList.items.emplace_back(completion, kind, sortText); @@ -897,7 +901,7 @@ MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri, void MLIRDocument::getCodeActionForDiagnostic( const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity, - StringRef message, std::vector &edits) { + StringRef message, std::vector &edits) { // Ignore diagnostics that print the current operation. These are always // enabled for the language server, but not generally during normal // parsing/verification. @@ -913,7 +917,7 @@ void MLIRDocument::getCodeActionForDiagnostic( // Add a text edit for adding an expected-* diagnostic check for this // diagnostic. - lsp::TextEdit edit; + llvm::lsp::TextEdit edit; edit.range = lsp::Range(lsp::Position(pos.line, 0)); // Use the indent of the current line for the expected-* diagnostic. @@ -937,13 +941,14 @@ MLIRDocument::convertToBytecode() { // conceptually be relaxed. if (!llvm::hasSingleElement(parsedIR)) { if (parsedIR.empty()) { - return llvm::make_error( + return llvm::make_error( "expected a single and valid top-level operation, please ensure " "there are no errors", - lsp::ErrorCode::RequestFailed); + llvm::lsp::ErrorCode::RequestFailed); } - return llvm::make_error( - "expected a single top-level operation", lsp::ErrorCode::RequestFailed); + return llvm::make_error( + "expected a single top-level operation", + llvm::lsp::ErrorCode::RequestFailed); } lsp::MLIRConvertBytecodeResult result; @@ -1134,7 +1139,7 @@ void MLIRTextFile::findDocumentSymbols( lsp::Position endPos((i == e - 1) ? totalNumLines - 1 : chunks[i + 1]->lineOffset); lsp::DocumentSymbol symbol("", - lsp::SymbolKind::Namespace, + llvm::lsp::SymbolKind::Namespace, /*range=*/lsp::Range(startPos, endPos), /*selectionRange=*/lsp::Range(startPos)); chunk.document.findDocumentSymbols(symbol.children); @@ -1167,10 +1172,10 @@ lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri, uri, completePos, context.getDialectRegistry()); // Adjust any completion locations. - for (lsp::CompletionItem &item : completionList.items) { + for (llvm::lsp::CompletionItem &item : completionList.items) { if (item.textEdit) chunk.adjustLocForChunkOffset(item.textEdit->range); - for (lsp::TextEdit &edit : item.additionalTextEdits) + for (llvm::lsp::TextEdit &edit : item.additionalTextEdits) chunk.adjustLocForChunkOffset(edit.range); } return completionList; @@ -1194,10 +1199,10 @@ void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri, StringRef severity; switch (diag.severity) { - case lsp::DiagnosticSeverity::Error: + case llvm::lsp::DiagnosticSeverity::Error: severity = "error"; break; - case lsp::DiagnosticSeverity::Warning: + case llvm::lsp::DiagnosticSeverity::Warning: severity = "warning"; break; default: @@ -1205,7 +1210,7 @@ void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri, } // Get edits for the diagnostic. - std::vector edits; + std::vector edits; chunk.document.getCodeActionForDiagnostic(uri, diagPos, severity, diag.message, edits); @@ -1221,7 +1226,7 @@ void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri, } } // Fixup the locations for any edits. - for (lsp::TextEdit &edit : edits) + for (llvm::lsp::TextEdit &edit : edits) chunk.adjustLocForChunkOffset(edit.range); action.edit.emplace(); @@ -1236,9 +1241,9 @@ llvm::Expected MLIRTextFile::convertToBytecode() { // Bail out if there is more than one chunk, bytecode wants a single module. if (chunks.size() != 1) { - return llvm::make_error( + return llvm::make_error( "unexpected split file, please remove all `// -----`", - lsp::ErrorCode::RequestFailed); + llvm::lsp::ErrorCode::RequestFailed); } return chunks.front()->document.convertToBytecode(); } @@ -1283,7 +1288,7 @@ lsp::MLIRServer::~MLIRServer() = default; void lsp::MLIRServer::addOrUpdateDocument( const URIForFile &uri, StringRef contents, int64_t version, - std::vector &diagnostics) { + std::vector &diagnostics) { impl->files[uri.file()] = std::make_unique( uri, contents, version, impl->registry_fn, diagnostics); } @@ -1298,17 +1303,17 @@ std::optional lsp::MLIRServer::removeDocument(const URIForFile &uri) { return version; } -void lsp::MLIRServer::getLocationsOf(const URIForFile &uri, - const Position &defPos, - std::vector &locations) { +void lsp::MLIRServer::getLocationsOf( + const URIForFile &uri, const Position &defPos, + std::vector &locations) { auto fileIt = impl->files.find(uri.file()); if (fileIt != impl->files.end()) fileIt->second->getLocationsOf(uri, defPos, locations); } -void lsp::MLIRServer::findReferencesOf(const URIForFile &uri, - const Position &pos, - std::vector &references) { +void lsp::MLIRServer::findReferencesOf( + const URIForFile &uri, const Position &pos, + std::vector &references) { auto fileIt = impl->files.find(uri.file()); if (fileIt != impl->files.end()) fileIt->second->findReferencesOf(uri, pos, references); @@ -1367,17 +1372,17 @@ lsp::MLIRServer::convertFromBytecode(const URIForFile &uri) { // Try to parse the given source file. Block parsedBlock; if (failed(parseSourceFile(uri.file(), &parsedBlock, parserConfig))) { - return llvm::make_error( + return llvm::make_error( "failed to parse bytecode source file: " + errorMsg, - lsp::ErrorCode::RequestFailed); + llvm::lsp::ErrorCode::RequestFailed); } // TODO: We currently expect a single top-level operation, but this could // conceptually be relaxed. if (!llvm::hasSingleElement(parsedBlock)) { - return llvm::make_error( + return llvm::make_error( "expected bytecode to contain a single top-level operation", - lsp::ErrorCode::RequestFailed); + llvm::lsp::ErrorCode::RequestFailed); } // Print the module to a buffer. @@ -1401,9 +1406,9 @@ llvm::Expected lsp::MLIRServer::convertToBytecode(const URIForFile &uri) { auto fileIt = impl->files.find(uri.file()); if (fileIt == impl->files.end()) { - return llvm::make_error( + return llvm::make_error( "language server does not contain an entry for this source file", - lsp::ErrorCode::RequestFailed); + llvm::lsp::ErrorCode::RequestFailed); } return fileIt->second->convertToBytecode(); } diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h index 85e69e69f6631..31a01fec8bbc9 100644 --- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h +++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h @@ -9,6 +9,7 @@ #ifndef LIB_MLIR_TOOLS_MLIRLSPSERVER_SERVER_H_ #define LIB_MLIR_TOOLS_MLIRLSPSERVER_SERVER_H_ +#include "Protocol.h" #include "mlir/Support/LLVM.h" #include "mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h" #include "llvm/Support/Error.h" @@ -19,16 +20,17 @@ namespace mlir { class DialectRegistry; namespace lsp { -struct CodeAction; -struct CodeActionContext; -struct CompletionList; -struct Diagnostic; -struct DocumentSymbol; -struct Hover; -struct Location; -struct MLIRConvertBytecodeResult; -struct Position; -struct Range; +using llvm::lsp::CodeAction; +using llvm::lsp::CodeActionContext; +using llvm::lsp::CompletionList; +using llvm::lsp::Diagnostic; +using llvm::lsp::DocumentSymbol; +using llvm::lsp::Hover; +using llvm::lsp::Location; +using llvm::lsp::MLIRConvertBytecodeResult; +using llvm::lsp::Position; +using llvm::lsp::Range; +using llvm::lsp::URIForFile; /// This class implements all of the MLIR related functionality necessary for a /// language server. This class allows for keeping the MLIR specific logic diff --git a/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp b/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp index f1dc32615c6a3..d4589b240e39e 100644 --- a/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp @@ -9,14 +9,18 @@ #include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h" #include "LSPServer.h" #include "MLIRServer.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Transport.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Transport.h" #include "llvm/Support/Program.h" using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::JSONStreamStyle; +using llvm::lsp::JSONTransport; +using llvm::lsp::Logger; + LogicalResult mlir::MlirLspServerMain(int argc, char **argv, DialectRegistryFn registry_fn) { llvm::cl::opt inputStyle{ diff --git a/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp b/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp index a56e9a10f03f1..28aded304d388 100644 --- a/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp @@ -13,14 +13,11 @@ #include "Protocol.h" #include "llvm/Support/JSON.h" -using namespace mlir; -using namespace mlir::lsp; - //===----------------------------------------------------------------------===// // MLIRConvertBytecodeParams //===----------------------------------------------------------------------===// -bool mlir::lsp::fromJSON(const llvm::json::Value &value, +bool llvm::lsp::fromJSON(const llvm::json::Value &value, MLIRConvertBytecodeParams &result, llvm::json::Path path) { llvm::json::ObjectMapper o(value, path); @@ -31,6 +28,6 @@ bool mlir::lsp::fromJSON(const llvm::json::Value &value, // MLIRConvertBytecodeResult //===----------------------------------------------------------------------===// -llvm::json::Value mlir::lsp::toJSON(const MLIRConvertBytecodeResult &value) { +llvm::json::Value llvm::lsp::toJSON(const MLIRConvertBytecodeResult &value) { return llvm::json::Object{{"output", value.output}}; } diff --git a/mlir/lib/Tools/mlir-lsp-server/Protocol.h b/mlir/lib/Tools/mlir-lsp-server/Protocol.h index d910780e1ee92..ed0db4e591d8f 100644 --- a/mlir/lib/Tools/mlir-lsp-server/Protocol.h +++ b/mlir/lib/Tools/mlir-lsp-server/Protocol.h @@ -20,9 +20,9 @@ #ifndef LIB_MLIR_TOOLS_MLIRLSPSERVER_PROTOCOL_H_ #define LIB_MLIR_TOOLS_MLIRLSPSERVER_PROTOCOL_H_ -#include "mlir/Tools/lsp-server-support/Protocol.h" +#include "llvm/Support/LSP/Protocol.h" -namespace mlir { +namespace llvm { namespace lsp { //===----------------------------------------------------------------------===// // MLIRConvertBytecodeParams @@ -54,6 +54,6 @@ struct MLIRConvertBytecodeResult { llvm::json::Value toJSON(const MLIRConvertBytecodeResult &value); } // namespace lsp -} // namespace mlir +} // namespace llvm #endif diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt b/mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt index bf25b7e0a64f3..b41603fb67eb0 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt @@ -7,6 +7,9 @@ llvm_add_library(MLIRPdllLspServerLib ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-pdll-lsp-server + LINK_COMPONENTS + SupportLSP + LINK_LIBS PUBLIC MLIRPDLLCodeGen MLIRPDLLParser diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp index 82542a12a1807..7b23adcc7e2e1 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp @@ -10,8 +10,9 @@ #include "PDLLServer.h" #include "Protocol.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Transport.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Protocol.h" +#include "llvm/Support/LSP/Transport.h" #include #define DEBUG_TYPE "pdll-lsp-server" @@ -19,6 +20,30 @@ using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::Callback; +using llvm::lsp::CompletionList; +using llvm::lsp::CompletionParams; +using llvm::lsp::DidChangeTextDocumentParams; +using llvm::lsp::DidCloseTextDocumentParams; +using llvm::lsp::DidOpenTextDocumentParams; +using llvm::lsp::DocumentLinkParams; +using llvm::lsp::DocumentSymbol; +using llvm::lsp::DocumentSymbolParams; +using llvm::lsp::Hover; +using llvm::lsp::InitializedParams; +using llvm::lsp::InitializeParams; +using llvm::lsp::InlayHintsParams; +using llvm::lsp::JSONTransport; +using llvm::lsp::Location; +using llvm::lsp::Logger; +using llvm::lsp::MessageHandler; +using llvm::lsp::NoParams; +using llvm::lsp::OutgoingNotification; +using llvm::lsp::PublishDiagnosticsParams; +using llvm::lsp::ReferenceParams; +using llvm::lsp::TextDocumentPositionParams; +using llvm::lsp::TextDocumentSyncKind; + //===----------------------------------------------------------------------===// // LSPServer //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.h b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.h index 78c4c31100cbc..42c0a5d7b6d2b 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.h +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.h @@ -13,17 +13,19 @@ namespace llvm { struct LogicalResult; +namespace lsp { +class JSONTransport; +} // namespace lsp } // namespace llvm namespace mlir { namespace lsp { -class JSONTransport; class PDLLServer; /// Run the main loop of the LSP server using the given PDLL server and /// transport. llvm::LogicalResult runPdllLSPServer(PDLLServer &server, - JSONTransport &transport); + llvm::lsp::JSONTransport &transport); } // namespace lsp } // namespace mlir diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.cpp index 287a131ecd17d..5dea130675cdb 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.cpp +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.cpp @@ -9,14 +9,17 @@ #include "mlir/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.h" #include "LSPServer.h" #include "PDLLServer.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Transport.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Transport.h" #include "llvm/Support/Program.h" using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::JSONStreamStyle; +using llvm::lsp::Logger; + LogicalResult mlir::MlirPdllLspServerMain(int argc, char **argv) { llvm::cl::opt inputStyle{ "input-style", @@ -72,7 +75,8 @@ LogicalResult mlir::MlirPdllLspServerMain(int argc, char **argv) { // Configure the transport used for communication. llvm::sys::ChangeStdinToBinary(); - JSONTransport transport(stdin, llvm::outs(), inputStyle, prettyPrint); + llvm::lsp::JSONTransport transport(stdin, llvm::outs(), inputStyle, + prettyPrint); // Configure the servers and start the main language server. PDLLServer::Options options(compilationDatabases, extraIncludeDirs); diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp index 84f529ae16401..60b9567ff7804 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp @@ -23,13 +23,13 @@ #include "mlir/Tools/PDLL/Parser/CodeComplete.h" #include "mlir/Tools/PDLL/Parser/Parser.h" #include "mlir/Tools/lsp-server-support/CompilationDatabase.h" -#include "mlir/Tools/lsp-server-support/Logging.h" #include "mlir/Tools/lsp-server-support/SourceMgrUtils.h" #include "llvm/ADT/IntervalMap.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/FileSystem.h" +#include "llvm/Support/LSP/Logging.h" #include "llvm/Support/Path.h" #include @@ -38,17 +38,19 @@ using namespace mlir::pdll; /// Returns a language server uri for the given source location. `mainFileURI` /// corresponds to the uri for the main file of the source manager. -static lsp::URIForFile getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc, - const lsp::URIForFile &mainFileURI) { +static llvm::lsp::URIForFile +getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc, + const llvm::lsp::URIForFile &mainFileURI) { int bufferId = mgr.FindBufferContainingLoc(loc.Start); if (bufferId == 0 || bufferId == static_cast(mgr.getMainFileID())) return mainFileURI; - llvm::Expected fileForLoc = lsp::URIForFile::fromFile( - mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier()); + llvm::Expected fileForLoc = + llvm::lsp::URIForFile::fromFile( + mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier()); if (fileForLoc) return *fileForLoc; - lsp::Logger::error("Failed to create URI for include file: {0}", - llvm::toString(fileForLoc.takeError())); + llvm::lsp::Logger::error("Failed to create URI for include file: {0}", + llvm::toString(fileForLoc.takeError())); return mainFileURI; } @@ -59,16 +61,18 @@ static bool isMainFileLoc(llvm::SourceMgr &mgr, SMRange loc) { } /// Returns a language server location from the given source range. -static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range, - const lsp::URIForFile &uri) { - return lsp::Location(getURIFromLoc(mgr, range, uri), lsp::Range(mgr, range)); +static llvm::lsp::Location +getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range, + const llvm::lsp::URIForFile &uri) { + return llvm::lsp::Location(getURIFromLoc(mgr, range, uri), + llvm::lsp::Range(mgr, range)); } /// Convert the given MLIR diagnostic to the LSP form. -static std::optional +static std::optional getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag, - const lsp::URIForFile &uri) { - lsp::Diagnostic lspDiag; + const llvm::lsp::URIForFile &uri) { + llvm::lsp::Diagnostic lspDiag; lspDiag.source = "pdll"; // FIXME: Right now all of the diagnostics are treated as parser issues, but @@ -76,7 +80,8 @@ getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag, lspDiag.category = "Parse Error"; // Try to grab a file location for this diagnostic. - lsp::Location loc = getLocationFromLoc(sourceMgr, diag.getLocation(), uri); + llvm::lsp::Location loc = + getLocationFromLoc(sourceMgr, diag.getLocation(), uri); lspDiag.range = loc.range; // Skip diagnostics that weren't emitted within the main file. @@ -88,19 +93,19 @@ getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag, case ast::Diagnostic::Severity::DK_Note: llvm_unreachable("expected notes to be handled separately"); case ast::Diagnostic::Severity::DK_Warning: - lspDiag.severity = lsp::DiagnosticSeverity::Warning; + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning; break; case ast::Diagnostic::Severity::DK_Error: - lspDiag.severity = lsp::DiagnosticSeverity::Error; + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error; break; case ast::Diagnostic::Severity::DK_Remark: - lspDiag.severity = lsp::DiagnosticSeverity::Information; + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information; break; } lspDiag.message = diag.getMessage().str(); // Attach any notes to the main diagnostic as related information. - std::vector relatedDiags; + std::vector relatedDiags; for (const ast::Diagnostic ¬e : diag.getNotes()) { relatedDiags.emplace_back( getLocationFromLoc(sourceMgr, note.getLocation(), uri), @@ -259,9 +264,9 @@ namespace { /// This class represents all of the information pertaining to a specific PDL /// document. struct PDLDocument { - PDLDocument(const lsp::URIForFile &uri, StringRef contents, + PDLDocument(const llvm::lsp::URIForFile &uri, StringRef contents, const std::vector &extraDirs, - std::vector &diagnostics); + std::vector &diagnostics); PDLDocument(const PDLDocument &) = delete; PDLDocument &operator=(const PDLDocument &) = delete; @@ -269,76 +274,83 @@ struct PDLDocument { // Definitions and References //===--------------------------------------------------------------------===// - void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos, - std::vector &locations); - void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos, - std::vector &references); + void getLocationsOf(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &defPos, + std::vector &locations); + void findReferencesOf(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &pos, + std::vector &references); //===--------------------------------------------------------------------===// // Document Links //===--------------------------------------------------------------------===// - void getDocumentLinks(const lsp::URIForFile &uri, - std::vector &links); + void getDocumentLinks(const llvm::lsp::URIForFile &uri, + std::vector &links); //===--------------------------------------------------------------------===// // Hover //===--------------------------------------------------------------------===// - std::optional findHover(const lsp::URIForFile &uri, - const lsp::Position &hoverPos); - std::optional findHover(const ast::Decl *decl, - const SMRange &hoverRange); - lsp::Hover buildHoverForOpName(const ods::Operation *op, - const SMRange &hoverRange); - lsp::Hover buildHoverForVariable(const ast::VariableDecl *varDecl, - const SMRange &hoverRange); - lsp::Hover buildHoverForPattern(const ast::PatternDecl *decl, - const SMRange &hoverRange); - lsp::Hover buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl, + std::optional + findHover(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &hoverPos); + std::optional findHover(const ast::Decl *decl, + const SMRange &hoverRange); + llvm::lsp::Hover buildHoverForOpName(const ods::Operation *op, + const SMRange &hoverRange); + llvm::lsp::Hover buildHoverForVariable(const ast::VariableDecl *varDecl, const SMRange &hoverRange); + llvm::lsp::Hover buildHoverForPattern(const ast::PatternDecl *decl, + const SMRange &hoverRange); + llvm::lsp::Hover + buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl, + const SMRange &hoverRange); template - lsp::Hover buildHoverForUserConstraintOrRewrite(StringRef typeName, - const T *decl, - const SMRange &hoverRange); + llvm::lsp::Hover + buildHoverForUserConstraintOrRewrite(StringRef typeName, const T *decl, + const SMRange &hoverRange); //===--------------------------------------------------------------------===// // Document Symbols //===--------------------------------------------------------------------===// - void findDocumentSymbols(std::vector &symbols); + void findDocumentSymbols(std::vector &symbols); //===--------------------------------------------------------------------===// // Code Completion //===--------------------------------------------------------------------===// - lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri, - const lsp::Position &completePos); + llvm::lsp::CompletionList + getCodeCompletion(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &completePos); //===--------------------------------------------------------------------===// // Signature Help //===--------------------------------------------------------------------===// - lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri, - const lsp::Position &helpPos); + llvm::lsp::SignatureHelp getSignatureHelp(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &helpPos); //===--------------------------------------------------------------------===// // Inlay Hints //===--------------------------------------------------------------------===// - void getInlayHints(const lsp::URIForFile &uri, const lsp::Range &range, - std::vector &inlayHints); + void getInlayHints(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Range &range, + std::vector &inlayHints); void getInlayHintsFor(const ast::VariableDecl *decl, - const lsp::URIForFile &uri, - std::vector &inlayHints); - void getInlayHintsFor(const ast::CallExpr *expr, const lsp::URIForFile &uri, - std::vector &inlayHints); + const llvm::lsp::URIForFile &uri, + std::vector &inlayHints); + void getInlayHintsFor(const ast::CallExpr *expr, + const llvm::lsp::URIForFile &uri, + std::vector &inlayHints); void getInlayHintsFor(const ast::OperationExpr *expr, - const lsp::URIForFile &uri, - std::vector &inlayHints); + const llvm::lsp::URIForFile &uri, + std::vector &inlayHints); /// Add a parameter hint for the given expression using `label`. - void addParameterHintFor(std::vector &inlayHints, + void addParameterHintFor(std::vector &inlayHints, const ast::Expr *expr, StringRef label); //===--------------------------------------------------------------------===// @@ -372,13 +384,14 @@ struct PDLDocument { }; } // namespace -PDLDocument::PDLDocument(const lsp::URIForFile &uri, StringRef contents, +PDLDocument::PDLDocument(const llvm::lsp::URIForFile &uri, StringRef contents, const std::vector &extraDirs, - std::vector &diagnostics) + std::vector &diagnostics) : astContext(odsContext) { auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file()); if (!memBuffer) { - lsp::Logger::error("Failed to create memory buffer for file", uri.file()); + llvm::lsp::Logger::error("Failed to create memory buffer for file", + uri.file()); return; } @@ -412,9 +425,9 @@ PDLDocument::PDLDocument(const lsp::URIForFile &uri, StringRef contents, // PDLDocument: Definitions and References //===----------------------------------------------------------------------===// -void PDLDocument::getLocationsOf(const lsp::URIForFile &uri, - const lsp::Position &defPos, - std::vector &locations) { +void PDLDocument::getLocationsOf(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &defPos, + std::vector &locations) { SMLoc posLoc = defPos.getAsSMLoc(sourceMgr); const PDLIndexSymbol *symbol = index.lookup(posLoc); if (!symbol) @@ -423,9 +436,9 @@ void PDLDocument::getLocationsOf(const lsp::URIForFile &uri, locations.push_back(getLocationFromLoc(sourceMgr, symbol->getDefLoc(), uri)); } -void PDLDocument::findReferencesOf(const lsp::URIForFile &uri, - const lsp::Position &pos, - std::vector &references) { +void PDLDocument::findReferencesOf( + const llvm::lsp::URIForFile &uri, const llvm::lsp::Position &pos, + std::vector &references) { SMLoc posLoc = pos.getAsSMLoc(sourceMgr); const PDLIndexSymbol *symbol = index.lookup(posLoc); if (!symbol) @@ -440,8 +453,9 @@ void PDLDocument::findReferencesOf(const lsp::URIForFile &uri, // PDLDocument: Document Links //===--------------------------------------------------------------------===// -void PDLDocument::getDocumentLinks(const lsp::URIForFile &uri, - std::vector &links) { +void PDLDocument::getDocumentLinks( + const llvm::lsp::URIForFile &uri, + std::vector &links) { for (const lsp::SourceMgrInclude &include : parsedIncludes) links.emplace_back(include.range, include.uri); } @@ -450,9 +464,9 @@ void PDLDocument::getDocumentLinks(const lsp::URIForFile &uri, // PDLDocument: Hover //===----------------------------------------------------------------------===// -std::optional -PDLDocument::findHover(const lsp::URIForFile &uri, - const lsp::Position &hoverPos) { +std::optional +PDLDocument::findHover(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &hoverPos) { SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr); // Check for a reference to an include. @@ -474,8 +488,8 @@ PDLDocument::findHover(const lsp::URIForFile &uri, return findHover(decl, hoverRange); } -std::optional PDLDocument::findHover(const ast::Decl *decl, - const SMRange &hoverRange) { +std::optional +PDLDocument::findHover(const ast::Decl *decl, const SMRange &hoverRange) { // Add hover for variables. if (const auto *varDecl = dyn_cast(decl)) return buildHoverForVariable(varDecl, hoverRange); @@ -499,9 +513,9 @@ std::optional PDLDocument::findHover(const ast::Decl *decl, return std::nullopt; } -lsp::Hover PDLDocument::buildHoverForOpName(const ods::Operation *op, - const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); +llvm::lsp::Hover PDLDocument::buildHoverForOpName(const ods::Operation *op, + const SMRange &hoverRange) { + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); hoverOS << "**OpName**: `" << op->getName() << "`\n***\n" @@ -511,9 +525,10 @@ lsp::Hover PDLDocument::buildHoverForOpName(const ods::Operation *op, return hover; } -lsp::Hover PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl, - const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); +llvm::lsp::Hover +PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl, + const SMRange &hoverRange) { + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); hoverOS << "**Variable**: `" << varDecl->getName().getName() << "`\n***\n" @@ -522,9 +537,9 @@ lsp::Hover PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl, return hover; } -lsp::Hover PDLDocument::buildHoverForPattern(const ast::PatternDecl *decl, - const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); +llvm::lsp::Hover PDLDocument::buildHoverForPattern(const ast::PatternDecl *decl, + const SMRange &hoverRange) { + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); hoverOS << "**Pattern**"; @@ -545,10 +560,10 @@ lsp::Hover PDLDocument::buildHoverForPattern(const ast::PatternDecl *decl, return hover; } -lsp::Hover +llvm::lsp::Hover PDLDocument::buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl, const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); hoverOS << "**Constraint**: `"; @@ -573,9 +588,9 @@ PDLDocument::buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl, } template -lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite( +llvm::lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite( StringRef typeName, const T *decl, const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); hoverOS << "**" << typeName << "**: `" << decl->getName().getName() @@ -617,7 +632,7 @@ lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite( //===----------------------------------------------------------------------===// void PDLDocument::findDocumentSymbols( - std::vector &symbols) { + std::vector &symbols) { if (failed(astModule)) return; @@ -631,25 +646,28 @@ void PDLDocument::findDocumentSymbols( SMRange nameLoc = name ? name->getLoc() : patternDecl->getLoc(); SMRange bodyLoc(nameLoc.Start, patternDecl->getBody()->getLoc().End); - symbols.emplace_back( - name ? name->getName() : "", lsp::SymbolKind::Class, - lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc)); + symbols.emplace_back(name ? name->getName() : "", + llvm::lsp::SymbolKind::Class, + llvm::lsp::Range(sourceMgr, bodyLoc), + llvm::lsp::Range(sourceMgr, nameLoc)); } else if (const auto *cDecl = dyn_cast(decl)) { // TODO: Add source information for the code block body. SMRange nameLoc = cDecl->getName().getLoc(); SMRange bodyLoc = nameLoc; - symbols.emplace_back( - cDecl->getName().getName(), lsp::SymbolKind::Function, - lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc)); + symbols.emplace_back(cDecl->getName().getName(), + llvm::lsp::SymbolKind::Function, + llvm::lsp::Range(sourceMgr, bodyLoc), + llvm::lsp::Range(sourceMgr, nameLoc)); } else if (const auto *cDecl = dyn_cast(decl)) { // TODO: Add source information for the code block body. SMRange nameLoc = cDecl->getName().getLoc(); SMRange bodyLoc = nameLoc; - symbols.emplace_back( - cDecl->getName().getName(), lsp::SymbolKind::Function, - lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc)); + symbols.emplace_back(cDecl->getName().getName(), + llvm::lsp::SymbolKind::Function, + llvm::lsp::Range(sourceMgr, bodyLoc), + llvm::lsp::Range(sourceMgr, nameLoc)); } } } @@ -662,7 +680,7 @@ namespace { class LSPCodeCompleteContext : public CodeCompleteContext { public: LSPCodeCompleteContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr, - lsp::CompletionList &completionList, + llvm::lsp::CompletionList &completionList, ods::Context &odsContext, ArrayRef includeDirs) : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr), @@ -674,13 +692,13 @@ class LSPCodeCompleteContext : public CodeCompleteContext { ArrayRef elementNames = tupleType.getElementNames(); for (unsigned i = 0, e = tupleType.size(); i < e; ++i) { // Push back a completion item that uses the result index. - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = llvm::formatv("{0} (field #{0})", i).str(); item.insertText = Twine(i).str(); item.filterText = item.sortText = item.insertText; - item.kind = lsp::CompletionItemKind::Field; + item.kind = llvm::lsp::CompletionItemKind::Field; item.detail = llvm::formatv("{0}: {1}", i, elementTypes[i]); - item.insertTextFormat = lsp::InsertTextFormat::PlainText; + item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText; completionList.items.emplace_back(item); // If the element has a name, push back a completion item with that name. @@ -705,11 +723,11 @@ class LSPCodeCompleteContext : public CodeCompleteContext { const ods::TypeConstraint &constraint = result.getConstraint(); // Push back a completion item that uses the result index. - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = llvm::formatv("{0} (field #{0})", it.index()).str(); item.insertText = Twine(it.index()).str(); item.filterText = item.sortText = item.insertText; - item.kind = lsp::CompletionItemKind::Field; + item.kind = llvm::lsp::CompletionItemKind::Field; switch (result.getVariableLengthKind()) { case ods::VariableLengthKind::Single: item.detail = llvm::formatv("{0}: Value", it.index()).str(); @@ -721,12 +739,12 @@ class LSPCodeCompleteContext : public CodeCompleteContext { item.detail = llvm::formatv("{0}: ValueRange", it.index()).str(); break; } - item.documentation = lsp::MarkupContent{ - lsp::MarkupKind::Markdown, + item.documentation = llvm::lsp::MarkupContent{ + llvm::lsp::MarkupKind::Markdown, llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(), constraint.getCppClass()) .str()}; - item.insertTextFormat = lsp::InsertTextFormat::PlainText; + item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText; completionList.items.emplace_back(item); // If the result has a name, push back a completion item with the result @@ -750,16 +768,16 @@ class LSPCodeCompleteContext : public CodeCompleteContext { for (const ods::Attribute &attr : odsOp->getAttributes()) { const ods::AttributeConstraint &constraint = attr.getConstraint(); - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = attr.getName().str(); - item.kind = lsp::CompletionItemKind::Field; + item.kind = llvm::lsp::CompletionItemKind::Field; item.detail = attr.isOptional() ? "optional" : ""; - item.documentation = lsp::MarkupContent{ - lsp::MarkupKind::Markdown, + item.documentation = llvm::lsp::MarkupContent{ + llvm::lsp::MarkupKind::Markdown, llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(), constraint.getCppClass()) .str()}; - item.insertTextFormat = lsp::InsertTextFormat::PlainText; + item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText; completionList.items.emplace_back(item); } } @@ -769,18 +787,18 @@ class LSPCodeCompleteContext : public CodeCompleteContext { const ast::DeclScope *scope) final { auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType, StringRef snippetText = "") { - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = constraint.str(); - item.kind = lsp::CompletionItemKind::Class; + item.kind = llvm::lsp::CompletionItemKind::Class; item.detail = (constraint + " constraint").str(); - item.documentation = lsp::MarkupContent{ - lsp::MarkupKind::Markdown, + item.documentation = llvm::lsp::MarkupContent{ + llvm::lsp::MarkupKind::Markdown, ("A single entity core constraint of type `" + mlirType + "`").str()}; item.sortText = "0"; item.insertText = snippetText.str(); item.insertTextFormat = snippetText.empty() - ? lsp::InsertTextFormat::PlainText - : lsp::InsertTextFormat::Snippet; + ? llvm::lsp::InsertTextFormat::PlainText + : llvm::lsp::InsertTextFormat::Snippet; completionList.items.emplace_back(item); }; @@ -812,9 +830,9 @@ class LSPCodeCompleteContext : public CodeCompleteContext { while (scope) { for (const ast::Decl *decl : scope->getDecls()) { if (const auto *cst = dyn_cast(decl)) { - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = cst->getName().getName().str(); - item.kind = lsp::CompletionItemKind::Interface; + item.kind = llvm::lsp::CompletionItemKind::Interface; item.sortText = "2_" + item.label; // Skip constraints that are not single-arg. We currently only @@ -841,8 +859,8 @@ class LSPCodeCompleteContext : public CodeCompleteContext { // Format the documentation for the constraint. if (std::optional doc = getDocumentationFor(sourceMgr, cst)) { - item.documentation = - lsp::MarkupContent{lsp::MarkupKind::Markdown, std::move(*doc)}; + item.documentation = llvm::lsp::MarkupContent{ + llvm::lsp::MarkupKind::Markdown, std::move(*doc)}; } completionList.items.emplace_back(item); @@ -856,10 +874,10 @@ class LSPCodeCompleteContext : public CodeCompleteContext { void codeCompleteDialectName() final { // Code complete known dialects. for (const ods::Dialect &dialect : odsContext.getDialects()) { - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = dialect.getName().str(); - item.kind = lsp::CompletionItemKind::Class; - item.insertTextFormat = lsp::InsertTextFormat::PlainText; + item.kind = llvm::lsp::CompletionItemKind::Class; + item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText; completionList.items.emplace_back(item); } } @@ -872,10 +890,10 @@ class LSPCodeCompleteContext : public CodeCompleteContext { for (const auto &it : dialect->getOperations()) { const ods::Operation &op = *it.second; - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = op.getName().drop_front(dialectName.size() + 1).str(); - item.kind = lsp::CompletionItemKind::Field; - item.insertTextFormat = lsp::InsertTextFormat::PlainText; + item.kind = llvm::lsp::CompletionItemKind::Field; + item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText; completionList.items.emplace_back(item); } } @@ -883,16 +901,16 @@ class LSPCodeCompleteContext : public CodeCompleteContext { void codeCompletePatternMetadata() final { auto addSimpleConstraint = [&](StringRef constraint, StringRef desc, StringRef snippetText = "") { - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = constraint.str(); - item.kind = lsp::CompletionItemKind::Class; + item.kind = llvm::lsp::CompletionItemKind::Class; item.detail = "pattern metadata"; item.documentation = - lsp::MarkupContent{lsp::MarkupKind::Markdown, desc.str()}; + llvm::lsp::MarkupContent{llvm::lsp::MarkupKind::Markdown, desc.str()}; item.insertText = snippetText.str(); item.insertTextFormat = snippetText.empty() - ? lsp::InsertTextFormat::PlainText - : lsp::InsertTextFormat::Snippet; + ? llvm::lsp::InsertTextFormat::PlainText + : llvm::lsp::InsertTextFormat::Snippet; completionList.items.emplace_back(item); }; @@ -913,10 +931,10 @@ class LSPCodeCompleteContext : public CodeCompleteContext { // Functor used to add a single include completion item. auto addIncludeCompletion = [&](StringRef path, bool isDirectory) { - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = path.str(); - item.kind = isDirectory ? lsp::CompletionItemKind::Folder - : lsp::CompletionItemKind::File; + item.kind = isDirectory ? llvm::lsp::CompletionItemKind::Folder + : llvm::lsp::CompletionItemKind::File; if (seenResults.insert(item.label).second) completionList.items.emplace_back(item); }; @@ -961,31 +979,31 @@ class LSPCodeCompleteContext : public CodeCompleteContext { // Sort the completion results to make sure the output is deterministic in // the face of different iteration schemes for different platforms. - llvm::sort(completionList.items, [](const lsp::CompletionItem &lhs, - const lsp::CompletionItem &rhs) { + llvm::sort(completionList.items, [](const llvm::lsp::CompletionItem &lhs, + const llvm::lsp::CompletionItem &rhs) { return lhs.label < rhs.label; }); } private: llvm::SourceMgr &sourceMgr; - lsp::CompletionList &completionList; + llvm::lsp::CompletionList &completionList; ods::Context &odsContext; ArrayRef includeDirs; }; } // namespace -lsp::CompletionList -PDLDocument::getCodeCompletion(const lsp::URIForFile &uri, - const lsp::Position &completePos) { +llvm::lsp::CompletionList +PDLDocument::getCodeCompletion(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &completePos) { SMLoc posLoc = completePos.getAsSMLoc(sourceMgr); if (!posLoc.isValid()) - return lsp::CompletionList(); + return llvm::lsp::CompletionList(); // To perform code completion, we run another parse of the module with the // code completion context provided. ods::Context tmpODSContext; - lsp::CompletionList completionList; + llvm::lsp::CompletionList completionList; LSPCodeCompleteContext lspCompleteContext(posLoc, sourceMgr, completionList, tmpODSContext, sourceMgr.getIncludeDirs()); @@ -1005,7 +1023,7 @@ namespace { class LSPSignatureHelpContext : public CodeCompleteContext { public: LSPSignatureHelpContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr, - lsp::SignatureHelp &signatureHelp, + llvm::lsp::SignatureHelp &signatureHelp, ods::Context &odsContext) : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr), signatureHelp(signatureHelp), odsContext(odsContext) {} @@ -1014,7 +1032,7 @@ class LSPSignatureHelpContext : public CodeCompleteContext { unsigned currentNumArgs) final { signatureHelp.activeParameter = currentNumArgs; - lsp::SignatureInformation signatureInfo; + llvm::lsp::SignatureInformation signatureInfo; { llvm::raw_string_ostream strOS(signatureInfo.label); strOS << callable->getName()->getName() << "("; @@ -1022,7 +1040,7 @@ class LSPSignatureHelpContext : public CodeCompleteContext { unsigned paramStart = strOS.str().size(); strOS << var->getName().getName() << ": " << var->getType(); unsigned paramEnd = strOS.str().size(); - signatureInfo.parameters.emplace_back(lsp::ParameterInformation{ + signatureInfo.parameters.emplace_back(llvm::lsp::ParameterInformation{ StringRef(strOS.str()).slice(paramStart, paramEnd).str(), std::make_pair(paramStart, paramEnd), /*paramDoc*/ std::string()}); }; @@ -1070,7 +1088,7 @@ class LSPSignatureHelpContext : public CodeCompleteContext { // not more than what is defined in ODS, as this will result in an error // anyways. if (odsOp && currentValue < values.size()) { - lsp::SignatureInformation signatureInfo; + llvm::lsp::SignatureInformation signatureInfo; // Build the signature label. { @@ -1099,7 +1117,7 @@ class LSPSignatureHelpContext : public CodeCompleteContext { } unsigned paramEnd = strOS.str().size(); - signatureInfo.parameters.emplace_back(lsp::ParameterInformation{ + signatureInfo.parameters.emplace_back(llvm::lsp::ParameterInformation{ StringRef(strOS.str()).slice(paramStart, paramEnd).str(), std::make_pair(paramStart, paramEnd), paramDoc}); }; @@ -1114,12 +1132,12 @@ class LSPSignatureHelpContext : public CodeCompleteContext { // If there aren't any arguments yet, we also add the generic signature. if (currentValue == 0 && (!odsOp || !values.empty())) { - lsp::SignatureInformation signatureInfo; + llvm::lsp::SignatureInformation signatureInfo; signatureInfo.label = llvm::formatv("(<{0}s>: {1}Range)", label, dataType).str(); signatureInfo.documentation = ("Generic operation " + label + " specification").str(); - signatureInfo.parameters.emplace_back(lsp::ParameterInformation{ + signatureInfo.parameters.emplace_back(llvm::lsp::ParameterInformation{ StringRef(signatureInfo.label).drop_front().drop_back().str(), std::pair(1, signatureInfo.label.size() - 1), ("All of the " + label + "s of the operation.").str()}); @@ -1129,21 +1147,22 @@ class LSPSignatureHelpContext : public CodeCompleteContext { private: llvm::SourceMgr &sourceMgr; - lsp::SignatureHelp &signatureHelp; + llvm::lsp::SignatureHelp &signatureHelp; ods::Context &odsContext; }; } // namespace -lsp::SignatureHelp PDLDocument::getSignatureHelp(const lsp::URIForFile &uri, - const lsp::Position &helpPos) { +llvm::lsp::SignatureHelp +PDLDocument::getSignatureHelp(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &helpPos) { SMLoc posLoc = helpPos.getAsSMLoc(sourceMgr); if (!posLoc.isValid()) - return lsp::SignatureHelp(); + return llvm::lsp::SignatureHelp(); // To perform code completion, we run another parse of the module with the // code completion context provided. ods::Context tmpODSContext; - lsp::SignatureHelp signatureHelp; + llvm::lsp::SignatureHelp signatureHelp; LSPSignatureHelpContext completeContext(posLoc, sourceMgr, signatureHelp, tmpODSContext); @@ -1173,9 +1192,9 @@ static bool shouldAddHintFor(const ast::Expr *expr, StringRef name) { return true; } -void PDLDocument::getInlayHints(const lsp::URIForFile &uri, - const lsp::Range &range, - std::vector &inlayHints) { +void PDLDocument::getInlayHints(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Range &range, + std::vector &inlayHints) { if (failed(astModule)) return; SMRange rangeLoc = range.getAsSMRange(sourceMgr); @@ -1198,9 +1217,9 @@ void PDLDocument::getInlayHints(const lsp::URIForFile &uri, }); } -void PDLDocument::getInlayHintsFor(const ast::VariableDecl *decl, - const lsp::URIForFile &uri, - std::vector &inlayHints) { +void PDLDocument::getInlayHintsFor( + const ast::VariableDecl *decl, const llvm::lsp::URIForFile &uri, + std::vector &inlayHints) { // Check to see if the variable has a constraint list, if it does we don't // provide initializer hints. if (!decl->getConstraints().empty()) @@ -1215,8 +1234,8 @@ void PDLDocument::getInlayHintsFor(const ast::VariableDecl *decl, return; } - lsp::InlayHint hint(lsp::InlayHintKind::Type, - lsp::Position(sourceMgr, decl->getLoc().End)); + llvm::lsp::InlayHint hint(llvm::lsp::InlayHintKind::Type, + llvm::lsp::Position(sourceMgr, decl->getLoc().End)); { llvm::raw_string_ostream labelOS(hint.label); labelOS << ": " << decl->getType(); @@ -1225,9 +1244,9 @@ void PDLDocument::getInlayHintsFor(const ast::VariableDecl *decl, inlayHints.emplace_back(std::move(hint)); } -void PDLDocument::getInlayHintsFor(const ast::CallExpr *expr, - const lsp::URIForFile &uri, - std::vector &inlayHints) { +void PDLDocument::getInlayHintsFor( + const ast::CallExpr *expr, const llvm::lsp::URIForFile &uri, + std::vector &inlayHints) { // Try to extract the callable of this call. const auto *callableRef = dyn_cast(expr->getCallableExpr()); const auto *callable = @@ -1242,9 +1261,9 @@ void PDLDocument::getInlayHintsFor(const ast::CallExpr *expr, std::get<1>(it)->getName().getName()); } -void PDLDocument::getInlayHintsFor(const ast::OperationExpr *expr, - const lsp::URIForFile &uri, - std::vector &inlayHints) { +void PDLDocument::getInlayHintsFor( + const ast::OperationExpr *expr, const llvm::lsp::URIForFile &uri, + std::vector &inlayHints) { // Check for ODS information. ast::OperationType opType = dyn_cast(expr->getType()); const auto *odsOp = opType ? opType.getODSOperation() : nullptr; @@ -1290,13 +1309,15 @@ void PDLDocument::getInlayHintsFor(const ast::OperationExpr *expr, "results"); } -void PDLDocument::addParameterHintFor(std::vector &inlayHints, - const ast::Expr *expr, StringRef label) { +void PDLDocument::addParameterHintFor( + std::vector &inlayHints, const ast::Expr *expr, + StringRef label) { if (!shouldAddHintFor(expr, label)) return; - lsp::InlayHint hint(lsp::InlayHintKind::Parameter, - lsp::Position(sourceMgr, expr->getLoc().Start)); + llvm::lsp::InlayHint hint( + llvm::lsp::InlayHintKind::Parameter, + llvm::lsp::Position(sourceMgr, expr->getLoc().Start)); hint.label = (label + ":").str(); hint.paddingRight = true; inlayHints.emplace_back(std::move(hint)); @@ -1342,22 +1363,24 @@ void PDLDocument::getPDLLViewOutput(raw_ostream &os, namespace { /// This class represents a single chunk of an PDL text file. struct PDLTextFileChunk { - PDLTextFileChunk(uint64_t lineOffset, const lsp::URIForFile &uri, + PDLTextFileChunk(uint64_t lineOffset, const llvm::lsp::URIForFile &uri, StringRef contents, const std::vector &extraDirs, - std::vector &diagnostics) + std::vector &diagnostics) : lineOffset(lineOffset), document(uri, contents, extraDirs, diagnostics) {} /// Adjust the line number of the given range to anchor at the beginning of /// the file, instead of the beginning of this chunk. - void adjustLocForChunkOffset(lsp::Range &range) { + void adjustLocForChunkOffset(llvm::lsp::Range &range) { adjustLocForChunkOffset(range.start); adjustLocForChunkOffset(range.end); } /// Adjust the line number of the given position to anchor at the beginning of /// the file, instead of the beginning of this chunk. - void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; } + void adjustLocForChunkOffset(llvm::lsp::Position &pos) { + pos.line += lineOffset; + } /// The line offset of this chunk from the beginning of the file. uint64_t lineOffset; @@ -1374,38 +1397,41 @@ namespace { /// This class represents a text file containing one or more PDL documents. class PDLTextFile { public: - PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents, + PDLTextFile(const llvm::lsp::URIForFile &uri, StringRef fileContents, int64_t version, const std::vector &extraDirs, - std::vector &diagnostics); + std::vector &diagnostics); /// Return the current version of this text file. int64_t getVersion() const { return version; } /// Update the file to the new version using the provided set of content /// changes. Returns failure if the update was unsuccessful. - LogicalResult update(const lsp::URIForFile &uri, int64_t newVersion, - ArrayRef changes, - std::vector &diagnostics); + LogicalResult + update(const llvm::lsp::URIForFile &uri, int64_t newVersion, + ArrayRef changes, + std::vector &diagnostics); //===--------------------------------------------------------------------===// // LSP Queries //===--------------------------------------------------------------------===// - void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos, - std::vector &locations); - void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos, - std::vector &references); - void getDocumentLinks(const lsp::URIForFile &uri, - std::vector &links); - std::optional findHover(const lsp::URIForFile &uri, - lsp::Position hoverPos); - void findDocumentSymbols(std::vector &symbols); - lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri, - lsp::Position completePos); - lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri, - lsp::Position helpPos); - void getInlayHints(const lsp::URIForFile &uri, lsp::Range range, - std::vector &inlayHints); + void getLocationsOf(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position defPos, + std::vector &locations); + void findReferencesOf(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position pos, + std::vector &references); + void getDocumentLinks(const llvm::lsp::URIForFile &uri, + std::vector &links); + std::optional findHover(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position hoverPos); + void findDocumentSymbols(std::vector &symbols); + llvm::lsp::CompletionList getCodeCompletion(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position completePos); + llvm::lsp::SignatureHelp getSignatureHelp(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position helpPos); + void getInlayHints(const llvm::lsp::URIForFile &uri, llvm::lsp::Range range, + std::vector &inlayHints); lsp::PDLLViewOutputResult getPDLLViewOutput(lsp::PDLLViewOutputKind kind); private: @@ -1413,14 +1439,14 @@ class PDLTextFile { std::vector>::iterator>; /// Initialize the text file from the given file contents. - void initialize(const lsp::URIForFile &uri, int64_t newVersion, - std::vector &diagnostics); + void initialize(const llvm::lsp::URIForFile &uri, int64_t newVersion, + std::vector &diagnostics); /// Find the PDL document that contains the given position, and update the /// position to be anchored at the start of the found chunk instead of the /// beginning of the file. - ChunkIterator getChunkItFor(lsp::Position &pos); - PDLTextFileChunk &getChunkFor(lsp::Position &pos) { + ChunkIterator getChunkItFor(llvm::lsp::Position &pos); + PDLTextFileChunk &getChunkFor(llvm::lsp::Position &pos) { return *getChunkItFor(pos); } @@ -1442,20 +1468,21 @@ class PDLTextFile { }; } // namespace -PDLTextFile::PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents, - int64_t version, +PDLTextFile::PDLTextFile(const llvm::lsp::URIForFile &uri, + StringRef fileContents, int64_t version, const std::vector &extraDirs, - std::vector &diagnostics) + std::vector &diagnostics) : contents(fileContents.str()), extraIncludeDirs(extraDirs) { initialize(uri, version, diagnostics); } LogicalResult -PDLTextFile::update(const lsp::URIForFile &uri, int64_t newVersion, - ArrayRef changes, - std::vector &diagnostics) { - if (failed(lsp::TextDocumentContentChangeEvent::applyTo(changes, contents))) { - lsp::Logger::error("Failed to update contents of {0}", uri.file()); +PDLTextFile::update(const llvm::lsp::URIForFile &uri, int64_t newVersion, + ArrayRef changes, + std::vector &diagnostics) { + if (failed(llvm::lsp::TextDocumentContentChangeEvent::applyTo(changes, + contents))) { + llvm::lsp::Logger::error("Failed to update contents of {0}", uri.file()); return failure(); } @@ -1464,36 +1491,37 @@ PDLTextFile::update(const lsp::URIForFile &uri, int64_t newVersion, return success(); } -void PDLTextFile::getLocationsOf(const lsp::URIForFile &uri, - lsp::Position defPos, - std::vector &locations) { +void PDLTextFile::getLocationsOf(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position defPos, + std::vector &locations) { PDLTextFileChunk &chunk = getChunkFor(defPos); chunk.document.getLocationsOf(uri, defPos, locations); // Adjust any locations within this file for the offset of this chunk. if (chunk.lineOffset == 0) return; - for (lsp::Location &loc : locations) + for (llvm::lsp::Location &loc : locations) if (loc.uri == uri) chunk.adjustLocForChunkOffset(loc.range); } -void PDLTextFile::findReferencesOf(const lsp::URIForFile &uri, - lsp::Position pos, - std::vector &references) { +void PDLTextFile::findReferencesOf( + const llvm::lsp::URIForFile &uri, llvm::lsp::Position pos, + std::vector &references) { PDLTextFileChunk &chunk = getChunkFor(pos); chunk.document.findReferencesOf(uri, pos, references); // Adjust any locations within this file for the offset of this chunk. if (chunk.lineOffset == 0) return; - for (lsp::Location &loc : references) + for (llvm::lsp::Location &loc : references) if (loc.uri == uri) chunk.adjustLocForChunkOffset(loc.range); } -void PDLTextFile::getDocumentLinks(const lsp::URIForFile &uri, - std::vector &links) { +void PDLTextFile::getDocumentLinks( + const llvm::lsp::URIForFile &uri, + std::vector &links) { chunks.front()->document.getDocumentLinks(uri, links); for (const auto &it : llvm::drop_begin(chunks)) { size_t currentNumLinks = links.size(); @@ -1506,10 +1534,12 @@ void PDLTextFile::getDocumentLinks(const lsp::URIForFile &uri, } } -std::optional PDLTextFile::findHover(const lsp::URIForFile &uri, - lsp::Position hoverPos) { +std::optional +PDLTextFile::findHover(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position hoverPos) { PDLTextFileChunk &chunk = getChunkFor(hoverPos); - std::optional hoverInfo = chunk.document.findHover(uri, hoverPos); + std::optional hoverInfo = + chunk.document.findHover(uri, hoverPos); // Adjust any locations within this file for the offset of this chunk. if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range) @@ -1518,7 +1548,7 @@ std::optional PDLTextFile::findHover(const lsp::URIForFile &uri, } void PDLTextFile::findDocumentSymbols( - std::vector &symbols) { + std::vector &symbols) { if (chunks.size() == 1) return chunks.front()->document.findDocumentSymbols(symbols); @@ -1526,27 +1556,27 @@ void PDLTextFile::findDocumentSymbols( // each chunk. for (unsigned i = 0, e = chunks.size(); i < e; ++i) { PDLTextFileChunk &chunk = *chunks[i]; - lsp::Position startPos(chunk.lineOffset); - lsp::Position endPos((i == e - 1) ? totalNumLines - 1 - : chunks[i + 1]->lineOffset); - lsp::DocumentSymbol symbol("", - lsp::SymbolKind::Namespace, - /*range=*/lsp::Range(startPos, endPos), - /*selectionRange=*/lsp::Range(startPos)); + llvm::lsp::Position startPos(chunk.lineOffset); + llvm::lsp::Position endPos((i == e - 1) ? totalNumLines - 1 + : chunks[i + 1]->lineOffset); + llvm::lsp::DocumentSymbol symbol( + "", llvm::lsp::SymbolKind::Namespace, + /*range=*/llvm::lsp::Range(startPos, endPos), + /*selectionRange=*/llvm::lsp::Range(startPos)); chunk.document.findDocumentSymbols(symbol.children); // Fixup the locations of document symbols within this chunk. if (i != 0) { - SmallVector symbolsToFix; - for (lsp::DocumentSymbol &childSymbol : symbol.children) + SmallVector symbolsToFix; + for (llvm::lsp::DocumentSymbol &childSymbol : symbol.children) symbolsToFix.push_back(&childSymbol); while (!symbolsToFix.empty()) { - lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val(); + llvm::lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val(); chunk.adjustLocForChunkOffset(symbol->range); chunk.adjustLocForChunkOffset(symbol->selectionRange); - for (lsp::DocumentSymbol &childSymbol : symbol->children) + for (llvm::lsp::DocumentSymbol &childSymbol : symbol->children) symbolsToFix.push_back(&childSymbol); } } @@ -1556,34 +1586,37 @@ void PDLTextFile::findDocumentSymbols( } } -lsp::CompletionList PDLTextFile::getCodeCompletion(const lsp::URIForFile &uri, - lsp::Position completePos) { +llvm::lsp::CompletionList +PDLTextFile::getCodeCompletion(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position completePos) { PDLTextFileChunk &chunk = getChunkFor(completePos); - lsp::CompletionList completionList = + llvm::lsp::CompletionList completionList = chunk.document.getCodeCompletion(uri, completePos); // Adjust any completion locations. - for (lsp::CompletionItem &item : completionList.items) { + for (llvm::lsp::CompletionItem &item : completionList.items) { if (item.textEdit) chunk.adjustLocForChunkOffset(item.textEdit->range); - for (lsp::TextEdit &edit : item.additionalTextEdits) + for (llvm::lsp::TextEdit &edit : item.additionalTextEdits) chunk.adjustLocForChunkOffset(edit.range); } return completionList; } -lsp::SignatureHelp PDLTextFile::getSignatureHelp(const lsp::URIForFile &uri, - lsp::Position helpPos) { +llvm::lsp::SignatureHelp +PDLTextFile::getSignatureHelp(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position helpPos) { return getChunkFor(helpPos).document.getSignatureHelp(uri, helpPos); } -void PDLTextFile::getInlayHints(const lsp::URIForFile &uri, lsp::Range range, - std::vector &inlayHints) { +void PDLTextFile::getInlayHints(const llvm::lsp::URIForFile &uri, + llvm::lsp::Range range, + std::vector &inlayHints) { auto startIt = getChunkItFor(range.start); auto endIt = getChunkItFor(range.end); // Functor used to get the chunks for a given file, and fixup any locations - auto getHintsForChunk = [&](ChunkIterator chunkIt, lsp::Range range) { + auto getHintsForChunk = [&](ChunkIterator chunkIt, llvm::lsp::Range range) { size_t currentNumHints = inlayHints.size(); chunkIt->document.getInlayHints(uri, range, inlayHints); @@ -1605,15 +1638,16 @@ void PDLTextFile::getInlayHints(const lsp::URIForFile &uri, lsp::Range range, // Otherwise, the range is split between multiple chunks. The first chunk // has the correct range start, but covers the total document. - getHintsForChunk(startIt, lsp::Range(range.start, getNumLines(startIt))); + getHintsForChunk(startIt, + llvm::lsp::Range(range.start, getNumLines(startIt))); // Every chunk in between uses the full document. for (++startIt; startIt != endIt; ++startIt) - getHintsForChunk(startIt, lsp::Range(0, getNumLines(startIt))); + getHintsForChunk(startIt, llvm::lsp::Range(0, getNumLines(startIt))); // The range for the last chunk starts at the beginning of the document, up // through the end of the input range. - getHintsForChunk(startIt, lsp::Range(0, range.end)); + getHintsForChunk(startIt, llvm::lsp::Range(0, range.end)); } lsp::PDLLViewOutputResult @@ -1632,8 +1666,9 @@ PDLTextFile::getPDLLViewOutput(lsp::PDLLViewOutputKind kind) { return result; } -void PDLTextFile::initialize(const lsp::URIForFile &uri, int64_t newVersion, - std::vector &diagnostics) { +void PDLTextFile::initialize(const llvm::lsp::URIForFile &uri, + int64_t newVersion, + std::vector &diagnostics) { version = newVersion; chunks.clear(); @@ -1653,7 +1688,7 @@ void PDLTextFile::initialize(const lsp::URIForFile &uri, int64_t newVersion, // Adjust locations used in diagnostics to account for the offset from the // beginning of the file. - for (lsp::Diagnostic &diag : + for (llvm::lsp::Diagnostic &diag : llvm::drop_begin(diagnostics, currentNumDiags)) { chunk->adjustLocForChunkOffset(diag.range); @@ -1668,14 +1703,15 @@ void PDLTextFile::initialize(const lsp::URIForFile &uri, int64_t newVersion, totalNumLines = lineOffset; } -PDLTextFile::ChunkIterator PDLTextFile::getChunkItFor(lsp::Position &pos) { +PDLTextFile::ChunkIterator +PDLTextFile::getChunkItFor(llvm::lsp::Position &pos) { if (chunks.size() == 1) return chunks.begin(); // Search for the first chunk with a greater line offset, the previous chunk // is the one that contains `pos`. auto it = llvm::upper_bound( - chunks, pos, [](const lsp::Position &pos, const auto &chunk) { + chunks, pos, [](const llvm::lsp::Position &pos, const auto &chunk) { return static_cast(pos.line) < chunk->lineOffset; }); ChunkIterator chunkIt(it == chunks.end() ? (chunks.end() - 1) : --it); @@ -1710,9 +1746,9 @@ lsp::PDLLServer::PDLLServer(const Options &options) : impl(std::make_unique(options)) {} lsp::PDLLServer::~PDLLServer() = default; -void lsp::PDLLServer::addDocument(const URIForFile &uri, StringRef contents, - int64_t version, - std::vector &diagnostics) { +void lsp::PDLLServer::addDocument( + const URIForFile &uri, StringRef contents, int64_t version, + std::vector &diagnostics) { // Build the set of additional include directories. std::vector additionalIncludeDirs = impl->options.extraDirs; const auto &fileInfo = impl->compilationDatabase.getFileInfo(uri.file()); @@ -1724,7 +1760,7 @@ void lsp::PDLLServer::addDocument(const URIForFile &uri, StringRef contents, void lsp::PDLLServer::updateDocument( const URIForFile &uri, ArrayRef changes, - int64_t version, std::vector &diagnostics) { + int64_t version, std::vector &diagnostics) { // Check that we actually have a document for this uri. auto it = impl->files.find(uri.file()); if (it == impl->files.end()) @@ -1746,17 +1782,17 @@ std::optional lsp::PDLLServer::removeDocument(const URIForFile &uri) { return version; } -void lsp::PDLLServer::getLocationsOf(const URIForFile &uri, - const Position &defPos, - std::vector &locations) { +void lsp::PDLLServer::getLocationsOf( + const URIForFile &uri, const Position &defPos, + std::vector &locations) { auto fileIt = impl->files.find(uri.file()); if (fileIt != impl->files.end()) fileIt->second->getLocationsOf(uri, defPos, locations); } -void lsp::PDLLServer::findReferencesOf(const URIForFile &uri, - const Position &pos, - std::vector &references) { +void lsp::PDLLServer::findReferencesOf( + const URIForFile &uri, const Position &pos, + std::vector &references) { auto fileIt = impl->files.find(uri.file()); if (fileIt != impl->files.end()) fileIt->second->findReferencesOf(uri, pos, references); @@ -1769,8 +1805,8 @@ void lsp::PDLLServer::getDocumentLinks( return fileIt->second->getDocumentLinks(uri, documentLinks); } -std::optional lsp::PDLLServer::findHover(const URIForFile &uri, - const Position &hoverPos) { +std::optional +lsp::PDLLServer::findHover(const URIForFile &uri, const Position &hoverPos) { auto fileIt = impl->files.find(uri.file()); if (fileIt != impl->files.end()) return fileIt->second->findHover(uri, hoverPos); @@ -1793,8 +1829,9 @@ lsp::PDLLServer::getCodeCompletion(const URIForFile &uri, return CompletionList(); } -lsp::SignatureHelp lsp::PDLLServer::getSignatureHelp(const URIForFile &uri, - const Position &helpPos) { +llvm::lsp::SignatureHelp +lsp::PDLLServer::getSignatureHelp(const URIForFile &uri, + const Position &helpPos) { auto fileIt = impl->files.find(uri.file()); if (fileIt != impl->files.end()) return fileIt->second->getSignatureHelp(uri, helpPos); diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h index 134431fa63bf8..d82014d6b0684 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h @@ -11,6 +11,7 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/LSP/Protocol.h" #include #include #include @@ -18,21 +19,22 @@ namespace mlir { namespace lsp { -struct Diagnostic; +using llvm::lsp::CompletionList; +using llvm::lsp::Diagnostic; +using llvm::lsp::DocumentLink; +using llvm::lsp::DocumentSymbol; +using llvm::lsp::Hover; +using llvm::lsp::InlayHint; +using llvm::lsp::Location; +using llvm::lsp::Position; +using llvm::lsp::Range; +using llvm::lsp::SignatureHelp; +using llvm::lsp::TextDocumentContentChangeEvent; +using llvm::lsp::URIForFile; + class CompilationDatabase; struct PDLLViewOutputResult; enum class PDLLViewOutputKind; -struct CompletionList; -struct DocumentLink; -struct DocumentSymbol; -struct Hover; -struct InlayHint; -struct Location; -struct Position; -struct Range; -struct SignatureHelp; -struct TextDocumentContentChangeEvent; -class URIForFile; /// This class implements all of the PDLL related functionality necessary for a /// language server. This class allows for keeping the PDLL specific logic diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp index 0c9896e3ec1b4..ace460536aa1b 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "Protocol.h" +#include "mlir/Support/LLVM.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/JSON.h" diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h b/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h index 0706316631851..a2775f8cbadc2 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h @@ -20,10 +20,12 @@ #ifndef LIB_MLIR_TOOLS_MLIRPDLLLSPSERVER_PROTOCOL_H_ #define LIB_MLIR_TOOLS_MLIRPDLLLSPSERVER_PROTOCOL_H_ -#include "mlir/Tools/lsp-server-support/Protocol.h" +#include "llvm/Support/LSP/Protocol.h" namespace mlir { namespace lsp { +using llvm::lsp::URIForFile; + //===----------------------------------------------------------------------===// // PDLLViewOutputParams //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/tblgen-lsp-server/CMakeLists.txt b/mlir/lib/Tools/tblgen-lsp-server/CMakeLists.txt index 80fc1ffe4029a..b21650ed03b6f 100644 --- a/mlir/lib/Tools/tblgen-lsp-server/CMakeLists.txt +++ b/mlir/lib/Tools/tblgen-lsp-server/CMakeLists.txt @@ -2,6 +2,7 @@ set(LLVM_LINK_COMPONENTS Demangle Support TableGen + SupportLSP ) llvm_add_library(TableGenLspServerLib diff --git a/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp b/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp index bb3c0a77747aa..95a457f3144c5 100644 --- a/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp +++ b/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp @@ -9,14 +9,33 @@ #include "LSPServer.h" #include "TableGenServer.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Protocol.h" -#include "mlir/Tools/lsp-server-support/Transport.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Protocol.h" +#include "llvm/Support/LSP/Transport.h" #include using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::Callback; +using llvm::lsp::DidChangeTextDocumentParams; +using llvm::lsp::DidCloseTextDocumentParams; +using llvm::lsp::DidOpenTextDocumentParams; +using llvm::lsp::DocumentLinkParams; +using llvm::lsp::Hover; +using llvm::lsp::InitializedParams; +using llvm::lsp::InitializeParams; +using llvm::lsp::JSONTransport; +using llvm::lsp::Location; +using llvm::lsp::Logger; +using llvm::lsp::MessageHandler; +using llvm::lsp::NoParams; +using llvm::lsp::OutgoingNotification; +using llvm::lsp::PublishDiagnosticsParams; +using llvm::lsp::ReferenceParams; +using llvm::lsp::TextDocumentPositionParams; +using llvm::lsp::TextDocumentSyncKind; + //===----------------------------------------------------------------------===// // LSPServer //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/tblgen-lsp-server/LSPServer.h b/mlir/lib/Tools/tblgen-lsp-server/LSPServer.h index 501a9dada8aab..596688b62f8da 100644 --- a/mlir/lib/Tools/tblgen-lsp-server/LSPServer.h +++ b/mlir/lib/Tools/tblgen-lsp-server/LSPServer.h @@ -13,17 +13,19 @@ namespace llvm { struct LogicalResult; +namespace lsp { +class JSONTransport; +} // namespace lsp } // namespace llvm namespace mlir { namespace lsp { -class JSONTransport; class TableGenServer; /// Run the main loop of the LSP server using the given TableGen server and /// transport. llvm::LogicalResult runTableGenLSPServer(TableGenServer &server, - JSONTransport &transport); + llvm::lsp::JSONTransport &transport); } // namespace lsp } // namespace mlir diff --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenLspServerMain.cpp b/mlir/lib/Tools/tblgen-lsp-server/TableGenLspServerMain.cpp index 21af78c9a506c..8014b8d6dba4a 100644 --- a/mlir/lib/Tools/tblgen-lsp-server/TableGenLspServerMain.cpp +++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenLspServerMain.cpp @@ -9,14 +9,18 @@ #include "mlir/Tools/tblgen-lsp-server/TableGenLspServerMain.h" #include "LSPServer.h" #include "TableGenServer.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Transport.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Transport.h" #include "llvm/Support/Program.h" using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::JSONStreamStyle; +using llvm::lsp::JSONTransport; +using llvm::lsp::Logger; + LogicalResult mlir::TableGenLspServerMain(int argc, char **argv) { llvm::cl::opt inputStyle{ "input-style", diff --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp index 5faeeae839f44..3080b78f187b1 100644 --- a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp +++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp @@ -10,12 +10,12 @@ #include "mlir/Support/IndentedOstream.h" #include "mlir/Tools/lsp-server-support/CompilationDatabase.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Protocol.h" #include "mlir/Tools/lsp-server-support/SourceMgrUtils.h" #include "llvm/ADT/IntervalMap.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/StringMap.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Protocol.h" #include "llvm/Support/Path.h" #include "llvm/TableGen/Parser.h" #include "llvm/TableGen/Record.h" @@ -36,45 +36,49 @@ static SMRange convertTokenLocToRange(SMLoc loc) { /// Returns a language server uri for the given source location. `mainFileURI` /// corresponds to the uri for the main file of the source manager. -static lsp::URIForFile getURIFromLoc(const SourceMgr &mgr, SMLoc loc, - const lsp::URIForFile &mainFileURI) { +static llvm::lsp::URIForFile +getURIFromLoc(const SourceMgr &mgr, SMLoc loc, + const llvm::lsp::URIForFile &mainFileURI) { int bufferId = mgr.FindBufferContainingLoc(loc); if (bufferId == 0 || bufferId == static_cast(mgr.getMainFileID())) return mainFileURI; - llvm::Expected fileForLoc = lsp::URIForFile::fromFile( - mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier()); + llvm::Expected fileForLoc = + llvm::lsp::URIForFile::fromFile( + mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier()); if (fileForLoc) return *fileForLoc; - lsp::Logger::error("Failed to create URI for include file: {0}", - llvm::toString(fileForLoc.takeError())); + llvm::lsp::Logger::error("Failed to create URI for include file: {0}", + llvm::toString(fileForLoc.takeError())); return mainFileURI; } /// Returns a language server location from the given source range. -static lsp::Location getLocationFromLoc(SourceMgr &mgr, SMRange loc, - const lsp::URIForFile &uri) { - return lsp::Location(getURIFromLoc(mgr, loc.Start, uri), - lsp::Range(mgr, loc)); +static llvm::lsp::Location +getLocationFromLoc(SourceMgr &mgr, SMRange loc, + const llvm::lsp::URIForFile &uri) { + return llvm::lsp::Location(getURIFromLoc(mgr, loc.Start, uri), + llvm::lsp::Range(mgr, loc)); } -static lsp::Location getLocationFromLoc(SourceMgr &mgr, SMLoc loc, - const lsp::URIForFile &uri) { +static llvm::lsp::Location +getLocationFromLoc(SourceMgr &mgr, SMLoc loc, + const llvm::lsp::URIForFile &uri) { return getLocationFromLoc(mgr, convertTokenLocToRange(loc), uri); } /// Convert the given TableGen diagnostic to the LSP form. -static std::optional +static std::optional getLspDiagnoticFromDiag(const llvm::SMDiagnostic &diag, - const lsp::URIForFile &uri) { + const llvm::lsp::URIForFile &uri) { auto *sourceMgr = const_cast(diag.getSourceMgr()); if (!sourceMgr || !diag.getLoc().isValid()) return std::nullopt; - lsp::Diagnostic lspDiag; + llvm::lsp::Diagnostic lspDiag; lspDiag.source = "tablegen"; lspDiag.category = "Parse Error"; // Try to grab a file location for this diagnostic. - lsp::Location loc = getLocationFromLoc(*sourceMgr, diag.getLoc(), uri); + llvm::lsp::Location loc = getLocationFromLoc(*sourceMgr, diag.getLoc(), uri); lspDiag.range = loc.range; // Skip diagnostics that weren't emitted within the main file. @@ -84,17 +88,17 @@ getLspDiagnoticFromDiag(const llvm::SMDiagnostic &diag, // Convert the severity for the diagnostic. switch (diag.getKind()) { case SourceMgr::DK_Warning: - lspDiag.severity = lsp::DiagnosticSeverity::Warning; + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning; break; case SourceMgr::DK_Error: - lspDiag.severity = lsp::DiagnosticSeverity::Error; + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error; break; case SourceMgr::DK_Note: // Notes are emitted separately from the main diagnostic, so we just treat // them as remarks given that we can't determine the diagnostic to relate // them to. case SourceMgr::DK_Remark: - lspDiag.severity = lsp::DiagnosticSeverity::Information; + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information; break; } lspDiag.message = diag.getMessage().str(); @@ -322,54 +326,59 @@ namespace { /// This class represents a text file containing one or more TableGen documents. class TableGenTextFile { public: - TableGenTextFile(const lsp::URIForFile &uri, StringRef fileContents, + TableGenTextFile(const llvm::lsp::URIForFile &uri, StringRef fileContents, int64_t version, const std::vector &extraIncludeDirs, - std::vector &diagnostics); + std::vector &diagnostics); /// Return the current version of this text file. int64_t getVersion() const { return version; } /// Update the file to the new version using the provided set of content /// changes. Returns failure if the update was unsuccessful. - LogicalResult update(const lsp::URIForFile &uri, int64_t newVersion, - ArrayRef changes, - std::vector &diagnostics); + LogicalResult + update(const llvm::lsp::URIForFile &uri, int64_t newVersion, + ArrayRef changes, + std::vector &diagnostics); //===--------------------------------------------------------------------===// // Definitions and References //===--------------------------------------------------------------------===// - void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos, - std::vector &locations); - void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos, - std::vector &references); + void getLocationsOf(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &defPos, + std::vector &locations); + void findReferencesOf(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &pos, + std::vector &references); //===--------------------------------------------------------------------===// // Document Links //===--------------------------------------------------------------------===// - void getDocumentLinks(const lsp::URIForFile &uri, - std::vector &links); + void getDocumentLinks(const llvm::lsp::URIForFile &uri, + std::vector &links); //===--------------------------------------------------------------------===// // Hover //===--------------------------------------------------------------------===// - std::optional findHover(const lsp::URIForFile &uri, - const lsp::Position &hoverPos); - lsp::Hover buildHoverForRecord(const Record *record, - const SMRange &hoverRange); - lsp::Hover buildHoverForTemplateArg(const Record *record, + std::optional + findHover(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &hoverPos); + llvm::lsp::Hover buildHoverForRecord(const Record *record, + const SMRange &hoverRange); + llvm::lsp::Hover buildHoverForTemplateArg(const Record *record, + const RecordVal *value, + const SMRange &hoverRange); + llvm::lsp::Hover buildHoverForField(const Record *record, const RecordVal *value, const SMRange &hoverRange); - lsp::Hover buildHoverForField(const Record *record, const RecordVal *value, - const SMRange &hoverRange); private: /// Initialize the text file from the given file contents. - void initialize(const lsp::URIForFile &uri, int64_t newVersion, - std::vector &diagnostics); + void initialize(const llvm::lsp::URIForFile &uri, int64_t newVersion, + std::vector &diagnostics); /// The full string contents of the file. std::string contents; @@ -395,9 +404,9 @@ class TableGenTextFile { } // namespace TableGenTextFile::TableGenTextFile( - const lsp::URIForFile &uri, StringRef fileContents, int64_t version, + const llvm::lsp::URIForFile &uri, StringRef fileContents, int64_t version, const std::vector &extraIncludeDirs, - std::vector &diagnostics) + std::vector &diagnostics) : contents(fileContents.str()), version(version) { // Build the set of include directories for this file. llvm::SmallString<32> uriDirectory(uri.file()); @@ -409,12 +418,13 @@ TableGenTextFile::TableGenTextFile( initialize(uri, version, diagnostics); } -LogicalResult -TableGenTextFile::update(const lsp::URIForFile &uri, int64_t newVersion, - ArrayRef changes, - std::vector &diagnostics) { - if (failed(lsp::TextDocumentContentChangeEvent::applyTo(changes, contents))) { - lsp::Logger::error("Failed to update contents of {0}", uri.file()); +LogicalResult TableGenTextFile::update( + const llvm::lsp::URIForFile &uri, int64_t newVersion, + ArrayRef changes, + std::vector &diagnostics) { + if (failed(llvm::lsp::TextDocumentContentChangeEvent::applyTo(changes, + contents))) { + llvm::lsp::Logger::error("Failed to update contents of {0}", uri.file()); return failure(); } @@ -423,9 +433,9 @@ TableGenTextFile::update(const lsp::URIForFile &uri, int64_t newVersion, return success(); } -void TableGenTextFile::initialize(const lsp::URIForFile &uri, - int64_t newVersion, - std::vector &diagnostics) { +void TableGenTextFile::initialize( + const llvm::lsp::URIForFile &uri, int64_t newVersion, + std::vector &diagnostics) { version = newVersion; sourceMgr = SourceMgr(); recordKeeper = std::make_unique(); @@ -433,7 +443,8 @@ void TableGenTextFile::initialize(const lsp::URIForFile &uri, // Build a buffer for this file. auto memBuffer = llvm::MemoryBuffer::getMemBuffer(contents, uri.file()); if (!memBuffer) { - lsp::Logger::error("Failed to create memory buffer for file", uri.file()); + llvm::lsp::Logger::error("Failed to create memory buffer for file", + uri.file()); return; } sourceMgr.setIncludeDirs(includeDirs); @@ -442,8 +453,8 @@ void TableGenTextFile::initialize(const lsp::URIForFile &uri, // This class provides a context argument for the SourceMgr diagnostic // handler. struct DiagHandlerContext { - std::vector &diagnostics; - const lsp::URIForFile &uri; + std::vector &diagnostics; + const llvm::lsp::URIForFile &uri; } handlerContext{diagnostics, uri}; // Set the diagnostic handler for the tablegen source manager. @@ -469,9 +480,9 @@ void TableGenTextFile::initialize(const lsp::URIForFile &uri, // TableGenTextFile: Definitions and References //===----------------------------------------------------------------------===// -void TableGenTextFile::getLocationsOf(const lsp::URIForFile &uri, - const lsp::Position &defPos, - std::vector &locations) { +void TableGenTextFile::getLocationsOf( + const llvm::lsp::URIForFile &uri, const llvm::lsp::Position &defPos, + std::vector &locations) { SMLoc posLoc = defPos.getAsSMLoc(sourceMgr); const TableGenIndexSymbol *symbol = index.lookup(posLoc); if (!symbol) @@ -492,8 +503,8 @@ void TableGenTextFile::getLocationsOf(const lsp::URIForFile &uri, } void TableGenTextFile::findReferencesOf( - const lsp::URIForFile &uri, const lsp::Position &pos, - std::vector &references) { + const llvm::lsp::URIForFile &uri, const llvm::lsp::Position &pos, + std::vector &references) { SMLoc posLoc = pos.getAsSMLoc(sourceMgr); const TableGenIndexSymbol *symbol = index.lookup(posLoc); if (!symbol) @@ -508,8 +519,9 @@ void TableGenTextFile::findReferencesOf( // TableGenTextFile: Document Links //===--------------------------------------------------------------------===// -void TableGenTextFile::getDocumentLinks(const lsp::URIForFile &uri, - std::vector &links) { +void TableGenTextFile::getDocumentLinks( + const llvm::lsp::URIForFile &uri, + std::vector &links) { for (const lsp::SourceMgrInclude &include : parsedIncludes) links.emplace_back(include.range, include.uri); } @@ -518,9 +530,9 @@ void TableGenTextFile::getDocumentLinks(const lsp::URIForFile &uri, // TableGenTextFile: Hover //===----------------------------------------------------------------------===// -std::optional -TableGenTextFile::findHover(const lsp::URIForFile &uri, - const lsp::Position &hoverPos) { +std::optional +TableGenTextFile::findHover(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &hoverPos) { // Check for a reference to an include. for (const lsp::SourceMgrInclude &include : parsedIncludes) if (include.range.contains(hoverPos)) @@ -546,9 +558,10 @@ TableGenTextFile::findHover(const lsp::URIForFile &uri, return buildHoverForField(recordVal->record, value, hoverRange); } -lsp::Hover TableGenTextFile::buildHoverForRecord(const Record *record, - const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); +llvm::lsp::Hover +TableGenTextFile::buildHoverForRecord(const Record *record, + const SMRange &hoverRange) { + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); @@ -590,9 +603,9 @@ lsp::Hover TableGenTextFile::buildHoverForRecord(const Record *record, return hover; } -lsp::Hover TableGenTextFile::buildHoverForTemplateArg( +llvm::lsp::Hover TableGenTextFile::buildHoverForTemplateArg( const Record *record, const RecordVal *value, const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); StringRef name = value->getName().rsplit(':').second; @@ -604,10 +617,9 @@ lsp::Hover TableGenTextFile::buildHoverForTemplateArg( return hover; } -lsp::Hover TableGenTextFile::buildHoverForField(const Record *record, - const RecordVal *value, - const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); +llvm::lsp::Hover TableGenTextFile::buildHoverForField( + const Record *record, const RecordVal *value, const SMRange &hoverRange) { + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); hoverOS << "**field** `" << value->getName() << "`\n***\nType: `"; @@ -722,7 +734,7 @@ void lsp::TableGenServer::getDocumentLinks( return fileIt->second->getDocumentLinks(uri, documentLinks); } -std::optional +std::optional lsp::TableGenServer::findHover(const URIForFile &uri, const Position &hoverPos) { auto fileIt = impl->files.find(uri.file()); diff --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.h b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.h index bdc851024a818..e54b8bcf35e24 100644 --- a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.h +++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.h @@ -11,6 +11,7 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/LSP/Protocol.h" #include #include #include @@ -18,13 +19,13 @@ namespace mlir { namespace lsp { -struct Diagnostic; -struct DocumentLink; -struct Hover; -struct Location; -struct Position; -struct TextDocumentContentChangeEvent; -class URIForFile; +using llvm::lsp::Diagnostic; +using llvm::lsp::DocumentLink; +using llvm::lsp::Hover; +using llvm::lsp::Location; +using llvm::lsp::Position; +using llvm::lsp::TextDocumentContentChangeEvent; +using llvm::lsp::URIForFile; /// This class implements all of the TableGen related functionality necessary /// for a language server. This class allows for keeping the TableGen specific diff --git a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp index 10d602fdfe728..712237bbbbca6 100644 --- a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp +++ b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp @@ -10,8 +10,8 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllExtensions.h" -#include "mlir/Tools/lsp-server-support/Protocol.h" #include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h" +#include "llvm/Support/LSP/Protocol.h" using namespace mlir; @@ -37,8 +37,8 @@ int main(int argc, char **argv) { // Returns the registry, except in testing mode when the URI contains // "-disable-lsp-registration". Testing for/example of registering dialects // based on URI. - auto registryFn = [®istry, - &empty](const lsp::URIForFile &uri) -> DialectRegistry & { + auto registryFn = [®istry, &empty]( + const llvm::lsp::URIForFile &uri) -> DialectRegistry & { (void)empty; #ifdef MLIR_INCLUDE_TESTS if (uri.uri().contains("-disable-lsp-registration")) diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt index c5f0d7e384d01..89332bce5fe05 100644 --- a/mlir/unittests/CMakeLists.txt +++ b/mlir/unittests/CMakeLists.txt @@ -18,7 +18,6 @@ add_subdirectory(Support) add_subdirectory(Rewrite) add_subdirectory(TableGen) add_subdirectory(Target) -add_subdirectory(Tools) add_subdirectory(Transforms) if(MLIR_ENABLE_EXECUTION_ENGINE) diff --git a/mlir/unittests/Tools/CMakeLists.txt b/mlir/unittests/Tools/CMakeLists.txt deleted file mode 100644 index a97588d928668..0000000000000 --- a/mlir/unittests/Tools/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(lsp-server-support) diff --git a/mlir/unittests/Tools/lsp-server-support/CMakeLists.txt b/mlir/unittests/Tools/lsp-server-support/CMakeLists.txt deleted file mode 100644 index c539c9bc5101f..0000000000000 --- a/mlir/unittests/Tools/lsp-server-support/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -add_mlir_unittest(MLIRLspServerSupportTests - Protocol.cpp - Transport.cpp -) -mlir_target_link_libraries(MLIRLspServerSupportTests - PRIVATE - MLIRLspServerSupportLib)