From 34edc4061c91d47db0cd4a159bc510331c1a3dbd Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Thu, 17 Apr 2025 22:09:10 -0400 Subject: [PATCH 1/2] [mlir][spirv] Switch to `llvm::interleaved`. NFC. Clean up printing code by switching to `llvm::interleaved` from https://github.com/llvm/llvm-project/pull/135517. --- mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp | 11 ++++----- mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp | 18 +++++++-------- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 23 ++++++------------- 3 files changed, 20 insertions(+), 32 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp index ed9a30086deca..577959bbdbeaa 100644 --- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp @@ -15,6 +15,8 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Interfaces/CallInterfaces.h" +#include "llvm/Support/InterleavedRange.h" + #include "SPIRVOpUtils.h" #include "SPIRVParsingUtils.h" @@ -119,12 +121,9 @@ ParseResult BranchConditionalOp::parse(OpAsmParser &parser, void BranchConditionalOp::print(OpAsmPrinter &printer) { printer << ' ' << getCondition(); - if (auto weights = getBranchWeights()) { - printer << " ["; - llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) { - printer << llvm::cast(a).getInt(); - }); - printer << "]"; + if (std::optional weights = getBranchWeights()) { + printer << ' ' + << llvm::interleaved_array(weights->getAsValueRange()); } printer << ", "; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp index b71be23fdf47d..2ba6106896c1f 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/InterleavedRange.h" using namespace mlir; using namespace mlir::spirv; @@ -621,17 +622,14 @@ Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser, //===----------------------------------------------------------------------===// static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer) { - auto &os = printer.getStream(); printer << spirv::VerCapExtAttr::getKindName() << "<" - << spirv::stringifyVersion(triple.getVersion()) << ", ["; - llvm::interleaveComma( - triple.getCapabilities(), os, - [&](spirv::Capability cap) { os << spirv::stringifyCapability(cap); }); - printer << "], ["; - llvm::interleaveComma(triple.getExtensionsAttr(), os, [&](Attribute attr) { - os << llvm::cast(attr).getValue(); - }); - printer << "]>"; + << spirv::stringifyVersion(triple.getVersion()) << ", " + << llvm::interleaved_array(llvm::map_range( + triple.getCapabilities(), spirv::stringifyCapability)) + << ", " + << llvm::interleaved_array( + triple.getExtensionsAttr().getAsValueRange()) + << ">"; } static void print(spirv::TargetEnvAttr targetEnv, DialectAsmPrinter &printer) { diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 16e91b0cb2cfc..448b1cf578bd7 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -35,6 +35,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/InterleavedRange.h" #include #include #include @@ -808,8 +809,7 @@ void spirv::EntryPointOp::print(OpAsmPrinter &printer) { printer.printSymbolName(getFn()); auto interfaceVars = getInterface().getValue(); if (!interfaceVars.empty()) { - printer << ", "; - llvm::interleaveComma(interfaceVars, printer); + printer << ", " << llvm::interleaved(interfaceVars); } } @@ -862,13 +862,9 @@ void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) { printer << " "; printer.printSymbolName(getFn()); printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\""; - auto values = this->getValues(); - if (values.empty()) - return; - printer << ", "; - llvm::interleaveComma(values, printer, [&](Attribute a) { - printer << llvm::cast(a).getInt(); - }); + ArrayAttr values = this->getValues(); + if (!values.empty()) + printer << ", " << llvm::interleaved(values.getAsValueRange()); } //===----------------------------------------------------------------------===// @@ -1824,13 +1820,8 @@ ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser, void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) { printer << " "; printer.printSymbolName(getSymName()); - printer << " ("; - auto constituents = this->getConstituents().getValue(); - - if (!constituents.empty()) - llvm::interleaveComma(constituents, printer); - - printer << ") : " << getType(); + printer << " (" << llvm::interleaved(this->getConstituents().getValue()) + << ") : " << getType(); } LogicalResult spirv::SpecConstantCompositeOp::verify() { From 767ef4c1459816d3a7a32e29975cc7c6681d42af Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Fri, 18 Apr 2025 11:07:24 -0400 Subject: [PATCH 2/2] Address comments --- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 448b1cf578bd7..097a18f8a70eb 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -808,9 +808,8 @@ void spirv::EntryPointOp::print(OpAsmPrinter &printer) { printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" "; printer.printSymbolName(getFn()); auto interfaceVars = getInterface().getValue(); - if (!interfaceVars.empty()) { + if (!interfaceVars.empty()) printer << ", " << llvm::interleaved(interfaceVars); - } } LogicalResult spirv::EntryPointOp::verify() {