diff --git a/mlir/include/mlir/Conversion/CMakeLists.txt b/mlir/include/mlir/Conversion/CMakeLists.txt index d212bf3e395e7..9f76ab659215e 100644 --- a/mlir/include/mlir/Conversion/CMakeLists.txt +++ b/mlir/include/mlir/Conversion/CMakeLists.txt @@ -6,3 +6,5 @@ mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix Conversion) add_public_tablegen_target(MLIRConversionPassIncGen) add_mlir_doc(Passes ConversionPasses ./ -gen-pass-doc) + +add_subdirectory(ConvertToLLVM) diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/CMakeLists.txt b/mlir/include/mlir/Conversion/ConvertToLLVM/CMakeLists.txt new file mode 100644 index 0000000000000..54d7a03fc22df --- /dev/null +++ b/mlir/include/mlir/Conversion/ConvertToLLVM/CMakeLists.txt @@ -0,0 +1,7 @@ +set(LLVM_TARGET_DEFINITIONS ToLLVMInterface.td) +mlir_tablegen(ToLLVMAttrInterface.h.inc -gen-attr-interface-decls) +mlir_tablegen(ToLLVMAttrInterface.cpp.inc -gen-attr-interface-defs) +mlir_tablegen(ToLLVMOpInterface.h.inc -gen-op-interface-decls) +mlir_tablegen(ToLLVMOpInterface.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRConvertToLLVMInterfaceIncGen) +add_dependencies(mlir-generic-headers MLIRConvertToLLVMInterfaceIncGen) diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h index 00aeed9bf29dc..6fd043646acd3 100644 --- a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h +++ b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h @@ -11,6 +11,7 @@ #include "mlir/IR/DialectInterface.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" namespace mlir { class ConversionTarget; @@ -18,6 +19,7 @@ class LLVMTypeConverter; class MLIRContext; class Operation; class RewritePatternSet; +class AnalysisManager; /// Base class for dialect interfaces providing translation to LLVM IR. /// Dialects that can be translated should provide an implementation of this @@ -50,6 +52,18 @@ void populateConversionTargetFromOperation(Operation *op, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns); +/// Helper function for populating LLVM conversion patterns. If `op` implements +/// the `ConvertToLLVMOpInterface` interface, then the LLVM conversion pattern +/// attributes provided by the interface will be used to configure the +/// conversion target, type converter, and the pattern set. +void populateOpConvertToLLVMConversionPatterns(Operation *op, + ConversionTarget &target, + LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns); } // namespace mlir +#include "mlir/Conversion/ConvertToLLVM/ToLLVMAttrInterface.h.inc" + +#include "mlir/Conversion/ConvertToLLVM/ToLLVMOpInterface.h.inc" + #endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_H diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.td b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.td new file mode 100644 index 0000000000000..1331a9802c570 --- /dev/null +++ b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMInterface.td @@ -0,0 +1,76 @@ + +//===- ToLLVMInterface.td - Conversion to LLVM interfaces -----*- tablegen -*-===// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines interfaces for managing transformations, including populating +// pattern rewrites. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_TD +#define MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_TD + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// Attribute interface +//===----------------------------------------------------------------------===// + +def ConvertToLLVMAttrInterface : + AttrInterface<"ConvertToLLVMAttrInterface"> { + let description = [{ + The `ConvertToLLVMAttrInterface` attribute interfaces allows using + attributes to configure the convert to LLVM infrastructure, this includes: + - The conversion target. + - The LLVM type converter. + - The pattern set. + + This interface permits fined grained configuration of the `convert-to-llvm` + process. For example, attributes with target information like + `#nvvm.target` or `#rodcl.target` can leverage this interface for populating + patterns specific to a particular target. + }]; + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Populate the dialect conversion target, type converter and pattern set. + }], + /*retTy=*/"void", + /*methodName=*/"populateConvertToLLVMConversionPatterns", + /*args=*/(ins "::mlir::ConversionTarget&":$target, + "::mlir::LLVMTypeConverter&":$typeConverter, + "::mlir::RewritePatternSet&":$patternSet)> + ]; +} + +//===----------------------------------------------------------------------===// +// Op interface +//===----------------------------------------------------------------------===// + +def ConvertToLLVMOpInterface : OpInterface<"ConvertToLLVMOpInterface"> { + let description = [{ + Interface for collecting all convert to LLVM attributes stored in an + operation. See `ConvertToLLVMAttrInterface` for more information on these + attributes. + }]; + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Populate the provided vector with a list of convert to LLVM attributes + to apply. + }], + /*retTy=*/"void", + /*methodName=*/"getConvertToLLVMConversionAttrs", + /*args=*/(ins + "::llvm::SmallVectorImpl<::mlir::ConvertToLLVMAttrInterface>&":$attrs) + > + ]; +} + +#endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVMINTERFACE_TD diff --git a/mlir/include/mlir/Conversion/GPUCommon/GPUToLLVM.h b/mlir/include/mlir/Conversion/GPUCommon/GPUToLLVM.h new file mode 100644 index 0000000000000..ad8c39fe67661 --- /dev/null +++ b/mlir/include/mlir/Conversion/GPUCommon/GPUToLLVM.h @@ -0,0 +1,25 @@ +//===- GPUToLLVM.h - Convert GPU to LLVM dialect ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This files declares registration functions for converting GPU to LLVM. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_GPUCOMMON_GPUTOLLVM_H +#define MLIR_CONVERSION_GPUCOMMON_GPUTOLLVM_H + +namespace mlir { +class DialectRegistry; +namespace gpu { +/// Registers the `ConvertToLLVMOpInterface` interface on the `gpu::GPUModuleOP` +/// operation. +void registerConvertGpuToLLVMInterface(DialectRegistry ®istry); +} // namespace gpu +} // namespace mlir + +#endif // MLIR_CONVERSION_GPUCOMMON_GPUTOLLVM_H diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h new file mode 100644 index 0000000000000..6311630a23c8f --- /dev/null +++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVM.h @@ -0,0 +1,27 @@ +//===- GPUToNVVM.h - Convert GPU to NVVM dialect ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This files declares registration functions for converting GPU to NVVM. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H +#define MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H + +namespace mlir { +class DialectRegistry; +namespace NVVM { +/// Registers the `ConvertToLLVMAttrInterface` interface on the +/// `NVVM::NVVMTargetAttr` attribute. This interface populates the conversion +/// target, LLVM type converter, and pattern set for converting GPU operations +/// to NVVM. +void registerConvertGpuToNVVMInterface(DialectRegistry ®istry); +} // namespace NVVM +} // namespace mlir + +#endif // MLIR_CONVERSION_GPUTONVVM_GPUTONVVM_H diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h index 645e86a430962..fc7c967f1b62c 100644 --- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h +++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h @@ -31,6 +31,10 @@ LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type); /// Configure target to convert from the GPU dialect to NVVM. void configureGpuToNVVMConversionLegality(ConversionTarget &target); +/// Configure the LLVM type convert to convert types and address spaces from the +/// GPU dialect to NVVM. +void configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter); + /// Collect a set of patterns to convert from the GPU dialect to NVVM. void populateGpuToNVVMConversionPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns); diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 4d272ba219c6f..e394bae64e091 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -22,12 +22,20 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> { This is a generic pass to convert to LLVM, it uses the `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects the injection of conversion patterns. + + If `dynamic` is set to `true`, the pass will look for + `ConvertToLLVMAttrInterface` attributes and use them to further configure + the conversion process. This option also uses the `DataLayoutAnalysis` + analysis to configure the type converter. Enabling this option incurs in + extra overhead. }]; let constructor = "mlir::createConvertToLLVMPass()"; let options = [ ListOption<"filterDialects", "filter-dialects", "std::string", "Test conversion patterns of only the specified dialects">, + Option<"useDynamic", "dynamic", "bool", "false", + "Use op conversion attributes to configure the conversion">, ]; } diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h index 1f2ef26b45070..14a6a2787b3a5 100644 --- a/mlir/include/mlir/InitAllExtensions.h +++ b/mlir/include/mlir/InitAllExtensions.h @@ -18,6 +18,8 @@ #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/GPUCommon/GPUToLLVM.h" +#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h" #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" @@ -72,6 +74,8 @@ inline void registerAllExtensions(DialectRegistry ®istry) { registerConvertOpenMPToLLVMInterface(registry); ub::registerConvertUBToLLVMInterface(registry); registerConvertAMXToLLVMInterface(registry); + gpu::registerConvertGpuToLLVMInterface(registry); + NVVM::registerConvertGpuToNVVMInterface(registry); // Register all transform dialect extensions. affine::registerTransformDialectExtension(registry); diff --git a/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt index de3d850d520c0..c71711ba2ebed 100644 --- a/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/ConvertToLLVM/CMakeLists.txt @@ -7,6 +7,7 @@ add_mlir_conversion_library(MLIRConvertToLLVMInterface ToLLVMInterface.cpp DEPENDS + MLIRConvertToLLVMInterfaceIncGen LINK_LIBS PUBLIC MLIRIR @@ -21,6 +22,7 @@ add_mlir_conversion_library(MLIRConvertToLLVMPass LINK_LIBS PUBLIC MLIRIR + MLIRConvertToLLVMInterface MLIRLLVMCommonConversion MLIRLLVMDialect MLIRPass diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp index b2407a258c271..673ba814d338f 100644 --- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp +++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Analysis/DataLayoutAnalysis.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" @@ -27,6 +28,41 @@ namespace mlir { using namespace mlir; namespace { +/// Base class for creating the internal implementation of `convert-to-llvm` +/// passes. +class ConvertToLLVMPassInterface { +public: + ConvertToLLVMPassInterface(MLIRContext *context, + ArrayRef filterDialects); + virtual ~ConvertToLLVMPassInterface() = default; + + /// Get the dependent dialects used by `convert-to-llvm`. + static void getDependentDialects(DialectRegistry ®istry); + + /// Initialize the internal state of the `convert-to-llvm` pass + /// implementation. This method is invoked by `ConvertToLLVMPass::initialize`. + /// This method returns whether the initialization process failed. + virtual LogicalResult initialize() = 0; + + /// Transform `op` to LLVM with the conversions available in the pass. The + /// analysis manager can be used to query analyzes like `DataLayoutAnalysis` + /// to further configure the conversion process. This method is invoked by + /// `ConvertToLLVMPass::runOnOperation`. This method returns whether the + /// transformation process failed. + virtual LogicalResult transform(Operation *op, + AnalysisManager manager) const = 0; + +protected: + /// Visit the `ConvertToLLVMPatternInterface` dialect interfaces and call + /// `visitor` with each of the interfaces. If `filterDialects` is non-empty, + /// then `visitor` is invoked only with the dialects in the `filterDialects` + /// list. + LogicalResult visitInterfaces( + llvm::function_ref visitor); + MLIRContext *context; + /// List of dialects names to use as filters. + ArrayRef filterDialects; +}; /// This DialectExtension can be attached to the context, which will invoke the /// `apply()` method for every loaded dialect. If a dialect implements the @@ -58,74 +94,188 @@ class LoadDependentDialectExtension : public DialectExtensionBase { } }; +//===----------------------------------------------------------------------===// +// StaticConvertToLLVM +//===----------------------------------------------------------------------===// + +/// Static implementation of the `convert-to-llvm` pass. This version only looks +/// at dialect interfaces to configure the conversion process. +struct StaticConvertToLLVM : public ConvertToLLVMPassInterface { + /// Pattern set with conversions to LLVM. + std::shared_ptr patterns; + /// The conversion target. + std::shared_ptr target; + /// The LLVM type converter. + std::shared_ptr typeConverter; + using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface; + + /// Configure the conversion to LLVM at pass initialization. + LogicalResult initialize() final { + auto target = std::make_shared(*context); + auto typeConverter = std::make_shared(context); + RewritePatternSet tempPatterns(context); + target->addLegalDialect(); + // Populate the patterns with the dialect interface. + if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) { + iface->populateConvertToLLVMConversionPatterns( + *target, *typeConverter, tempPatterns); + }))) + return failure(); + this->patterns = + std::make_unique(std::move(tempPatterns)); + this->target = target; + this->typeConverter = typeConverter; + return success(); + } + + /// Apply the conversion driver. + LogicalResult transform(Operation *op, AnalysisManager manager) const final { + if (failed(applyPartialConversion(op, *target, *patterns))) + return failure(); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// DynamicConvertToLLVM +//===----------------------------------------------------------------------===// + +/// Dynamic implementation of the `convert-to-llvm` pass. This version inspects +/// the IR to configure the conversion to LLVM. +struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface { + /// A list of all the `ConvertToLLVMPatternInterface` dialect interfaces used + /// to partially configure the conversion process. + std::shared_ptr> + interfaces; + using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface; + + /// Collect the dialect interfaces used to configure the conversion process. + LogicalResult initialize() final { + auto interfaces = + std::make_shared>(); + // Collect the interfaces. + if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) { + interfaces->push_back(iface); + }))) + return failure(); + this->interfaces = interfaces; + return success(); + } + + /// Configure the conversion process and apply the conversion driver. + LogicalResult transform(Operation *op, AnalysisManager manager) const final { + RewritePatternSet patterns(context); + ConversionTarget target(*context); + target.addLegalDialect(); + // Get the data layout analysis. + const auto &dlAnalysis = manager.getAnalysis(); + LLVMTypeConverter typeConverter(context, &dlAnalysis); + + // Configure the conversion with dialect level interfaces. + for (ConvertToLLVMPatternInterface *iface : *interfaces) + iface->populateConvertToLLVMConversionPatterns(target, typeConverter, + patterns); + + // Configure the conversion attribute interfaces. + populateOpConvertToLLVMConversionPatterns(op, target, typeConverter, + patterns); + + // Apply the conversion. + if (failed(applyPartialConversion(op, target, std::move(patterns)))) + return failure(); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// ConvertToLLVMPass +//===----------------------------------------------------------------------===// + /// This is a generic pass to convert to LLVM, it uses the /// `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects /// the injection of conversion patterns. class ConvertToLLVMPass : public impl::ConvertToLLVMPassBase { - std::shared_ptr patterns; - std::shared_ptr target; - std::shared_ptr typeConverter; + std::shared_ptr impl; public: using impl::ConvertToLLVMPassBase::ConvertToLLVMPassBase; void getDependentDialects(DialectRegistry ®istry) const final { - registry.insert(); - registry.addExtensions(); + ConvertToLLVMPassInterface::getDependentDialects(registry); } LogicalResult initialize(MLIRContext *context) final { - RewritePatternSet tempPatterns(context); - auto target = std::make_shared(*context); - target->addLegalDialect(); - auto typeConverter = std::make_shared(context); - - if (!filterDialects.empty()) { - // Test mode: Populate only patterns from the specified dialects. Produce - // an error if the dialect is not loaded or does not implement the - // interface. - for (std::string &dialectName : filterDialects) { - Dialect *dialect = context->getLoadedDialect(dialectName); - if (!dialect) - return emitError(UnknownLoc::get(context)) - << "dialect not loaded: " << dialectName << "\n"; - auto *iface = dyn_cast(dialect); - if (!iface) - return emitError(UnknownLoc::get(context)) - << "dialect does not implement ConvertToLLVMPatternInterface: " - << dialectName << "\n"; - iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter, - tempPatterns); - } - } else { - // Normal mode: Populate all patterns from all dialects that implement the - // interface. - for (Dialect *dialect : context->getLoadedDialects()) { - // First time we encounter this dialect: if it implements the interface, - // let's populate patterns ! - auto *iface = dyn_cast(dialect); - if (!iface) - continue; - iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter, - tempPatterns); - } - } - - this->patterns = - std::make_unique(std::move(tempPatterns)); - this->target = target; - this->typeConverter = typeConverter; + std::shared_ptr impl; + // Choose the pass implementation. + if (useDynamic) + impl = std::make_shared(context, filterDialects); + else + impl = std::make_shared(context, filterDialects); + if (failed(impl->initialize())) + return failure(); + this->impl = impl; return success(); } void runOnOperation() final { - if (failed(applyPartialConversion(getOperation(), *target, *patterns))) - signalPassFailure(); + if (failed(impl->transform(getOperation(), getAnalysisManager()))) + return signalPassFailure(); } }; } // namespace +//===----------------------------------------------------------------------===// +// ConvertToLLVMPassInterface +//===----------------------------------------------------------------------===// + +ConvertToLLVMPassInterface::ConvertToLLVMPassInterface( + MLIRContext *context, ArrayRef filterDialects) + : context(context), filterDialects(filterDialects) {} + +void ConvertToLLVMPassInterface::getDependentDialects( + DialectRegistry ®istry) { + registry.insert(); + registry.addExtensions(); +} + +LogicalResult ConvertToLLVMPassInterface::visitInterfaces( + llvm::function_ref visitor) { + if (!filterDialects.empty()) { + // Test mode: Populate only patterns from the specified dialects. Produce + // an error if the dialect is not loaded or does not implement the + // interface. + for (StringRef dialectName : filterDialects) { + Dialect *dialect = context->getLoadedDialect(dialectName); + if (!dialect) + return emitError(UnknownLoc::get(context)) + << "dialect not loaded: " << dialectName << "\n"; + auto *iface = dyn_cast(dialect); + if (!iface) + return emitError(UnknownLoc::get(context)) + << "dialect does not implement ConvertToLLVMPatternInterface: " + << dialectName << "\n"; + visitor(iface); + } + } else { + // Normal mode: Populate all patterns from all dialects that implement the + // interface. + for (Dialect *dialect : context->getLoadedDialects()) { + // First time we encounter this dialect: if it implements the interface, + // let's populate patterns ! + auto *iface = dyn_cast(dialect); + if (!iface) + continue; + visitor(iface); + } + } + return success(); +} + +//===----------------------------------------------------------------------===// +// API +//===----------------------------------------------------------------------===// + void mlir::registerConvertToLLVMDependentDialectLoading( DialectRegistry ®istry) { registry.addExtensions(); diff --git a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp index 3a4e83b2a8838..252245dfbf541 100644 --- a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp +++ b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp @@ -30,3 +30,22 @@ void mlir::populateConversionTargetFromOperation( patterns); }); } + +void mlir::populateOpConvertToLLVMConversionPatterns( + Operation *op, ConversionTarget &target, LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns) { + auto iface = dyn_cast(op); + if (!iface) + iface = op->getParentOfType(); + if (!iface) + return; + SmallVector attrs; + iface.getConvertToLLVMConversionAttrs(attrs); + for (ConvertToLLVMAttrInterface attr : attrs) + attr.populateConvertToLLVMConversionPatterns(target, typeConverter, + patterns); +} + +#include "mlir/Conversion/ConvertToLLVM/ToLLVMAttrInterface.cpp.inc" + +#include "mlir/Conversion/ConvertToLLVM/ToLLVMOpInterface.cpp.inc" diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 92b28ff9c5873..1497d662dcdbd 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -22,6 +22,7 @@ #include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" +#include "mlir/Conversion/GPUCommon/GPUToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" @@ -1762,3 +1763,34 @@ void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter); patterns.add(converter, kernelBarePtrCallConv); } + +//===----------------------------------------------------------------------===// +// GPUModuleOp convert to LLVM op interface +//===----------------------------------------------------------------------===// + +namespace { +struct GPUModuleOpConvertToLLVMInterface + : public ConvertToLLVMOpInterface::ExternalModel< + GPUModuleOpConvertToLLVMInterface, gpu::GPUModuleOp> { + /// Get the conversion patterns from the target attribute. + void getConvertToLLVMConversionAttrs( + Operation *op, SmallVectorImpl &attrs) const; +}; +} // namespace + +void GPUModuleOpConvertToLLVMInterface::getConvertToLLVMConversionAttrs( + Operation *op, SmallVectorImpl &attrs) const { + auto module = cast(op); + ArrayAttr targetsAttr = module.getTargetsAttr(); + // Fail if there are no target attributes or there is more than one target. + if (!targetsAttr || targetsAttr.size() != 1) + return; + if (auto patternAttr = dyn_cast(targetsAttr[0])) + attrs.push_back(patternAttr); +} + +void mlir::gpu::registerConvertGpuToLLVMInterface(DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) { + gpu::GPUModuleOp::attachInterface(*ctx); + }); +} diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 04e85c2b337de..b343cf71e3a2e 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -15,8 +15,10 @@ #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" +#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" @@ -274,29 +276,7 @@ struct LowerGpuOpsToNVVMOpsPass } LLVMTypeConverter converter(m.getContext(), options); - // NVVM uses alloca in the default address space to represent private - // memory allocations, so drop private annotations. NVVM uses address - // space 3 for shared memory. NVVM uses the default address space to - // represent global memory. - populateGpuMemorySpaceAttributeConversions( - converter, [](gpu::AddressSpace space) -> unsigned { - switch (space) { - case gpu::AddressSpace::Global: - return static_cast( - NVVM::NVVMMemorySpace::kGlobalMemorySpace); - case gpu::AddressSpace::Workgroup: - return static_cast( - NVVM::NVVMMemorySpace::kSharedMemorySpace); - case gpu::AddressSpace::Private: - return 0; - } - llvm_unreachable("unknown address space enum value"); - return 0; - }); - // Lowering for MMAMatrixType. - converter.addConversion([&](gpu::MMAMatrixType type) -> Type { - return convertMMAToLLVMType(type); - }); + configureGpuToNVVMTypeConverter(converter); RewritePatternSet llvmPatterns(m.getContext()); arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns); @@ -332,6 +312,32 @@ void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) { target.addLegalOp(); } +void mlir::configureGpuToNVVMTypeConverter(LLVMTypeConverter &converter) { + // NVVM uses alloca in the default address space to represent private + // memory allocations, so drop private annotations. NVVM uses address + // space 3 for shared memory. NVVM uses the default address space to + // represent global memory. + populateGpuMemorySpaceAttributeConversions( + converter, [](gpu::AddressSpace space) -> unsigned { + switch (space) { + case gpu::AddressSpace::Global: + return static_cast( + NVVM::NVVMMemorySpace::kGlobalMemorySpace); + case gpu::AddressSpace::Workgroup: + return static_cast( + NVVM::NVVMMemorySpace::kSharedMemorySpace); + case gpu::AddressSpace::Private: + return 0; + } + llvm_unreachable("unknown address space enum value"); + return 0; + }); + // Lowering for MMAMatrixType. + converter.addConversion([&](gpu::MMAMatrixType type) -> Type { + return convertMMAToLLVMType(type); + }); +} + template static void populateOpPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns, StringRef f32Func, @@ -467,3 +473,34 @@ void mlir::populateGpuToNVVMConversionPatterns( populateOpPatterns(converter, patterns, "__nv_tanhf", "__nv_tanh"); } + +//===----------------------------------------------------------------------===// +// NVVMTargetAttr convert to LLVM attr interface +//===----------------------------------------------------------------------===// + +namespace { +struct NVVMTargetConvertToLLVMAttrInterface + : public ConvertToLLVMAttrInterface::ExternalModel< + NVVMTargetConvertToLLVMAttrInterface, NVVM::NVVMTargetAttr> { + /// Configure GPU to NVVM. + void populateConvertToLLVMConversionPatterns( + Attribute attr, ConversionTarget &target, + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const; +}; +} // namespace + +void NVVMTargetConvertToLLVMAttrInterface:: + populateConvertToLLVMConversionPatterns(Attribute attr, + ConversionTarget &target, + LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns) const { + configureGpuToNVVMConversionLegality(target); + configureGpuToNVVMTypeConverter(typeConverter); + populateGpuToNVVMConversionPatterns(typeConverter, patterns); +} + +void mlir::NVVM::registerConvertGpuToNVVMInterface(DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, NVVMDialect *dialect) { + NVVMTargetAttr::attachInterface(*ctx); + }); +} diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir new file mode 100644 index 0000000000000..ed7fa6508d5ad --- /dev/null +++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm-target-attr.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-opt %s --pass-pipeline="builtin.module(gpu.module(convert-to-llvm{dynamic=true}))" | FileCheck %s + +// CHECK-LABEL: gpu.module @nvvm_module +gpu.module @nvvm_module [#nvvm.target] { + // CHECK-LABEL: llvm.func @kernel_0() + func.func @kernel_0() -> index { + // CHECK: = nvvm.read.ptx.sreg.tid.x : i32 + // CHECK: = llvm.sext %{{.*}} : i32 to i64 + %tIdX = gpu.thread_id x + // CHECK: = nvvm.read.ptx.sreg.laneid range : i32 + // CHECK: = llvm.sext %{{.*}} : i32 to i64 + %laneId = gpu.lane_id + %sum = index.add %tIdX, %laneId + func.return %sum : index + } + +// CHECK-LABEL: llvm.func @kernel_1 +// CHECK: (%{{.*}}: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: i64) +// CHECK: attributes {gpu.kernel, gpu.known_block_size = array, nvvm.kernel, nvvm.maxntid = array} + gpu.func @kernel_1(%arg0 : memref>) kernel attributes {known_block_size = array} { + gpu.return + } +} + +// CHECK-LABEL: gpu.module @nvvm_module_2 +gpu.module @nvvm_module_2 { + // CHECK-LABEL: llvm.func @kernel_0() + func.func @kernel_0() -> index { + // CHECK: = gpu.thread_id x + %tIdX = gpu.thread_id x + // CHECK: = gpu.lane_id + %laneId = gpu.lane_id + %sum = index.add %tIdX, %laneId + func.return %sum : index + } + +// CHECK-LABEL: gpu.func @kernel_1 +// CHECK: (%{{.*}}: memref>) kernel attributes {known_block_size = array} + gpu.func @kernel_1(%arg0 : memref>) kernel attributes {known_block_size = array} { + gpu.return + } +}