diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp index ed9a30086deca..577959bbdbeaa 100644 --- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp @@ -15,6 +15,8 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Interfaces/CallInterfaces.h" +#include "llvm/Support/InterleavedRange.h" + #include "SPIRVOpUtils.h" #include "SPIRVParsingUtils.h" @@ -119,12 +121,9 @@ ParseResult BranchConditionalOp::parse(OpAsmParser &parser, void BranchConditionalOp::print(OpAsmPrinter &printer) { printer << ' ' << getCondition(); - if (auto weights = getBranchWeights()) { - printer << " ["; - llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) { - printer << llvm::cast(a).getInt(); - }); - printer << "]"; + if (std::optional weights = getBranchWeights()) { + printer << ' ' + << llvm::interleaved_array(weights->getAsValueRange()); } printer << ", "; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp index b71be23fdf47d..2ba6106896c1f 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/InterleavedRange.h" using namespace mlir; using namespace mlir::spirv; @@ -621,17 +622,14 @@ Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser, //===----------------------------------------------------------------------===// static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer) { - auto &os = printer.getStream(); printer << spirv::VerCapExtAttr::getKindName() << "<" - << spirv::stringifyVersion(triple.getVersion()) << ", ["; - llvm::interleaveComma( - triple.getCapabilities(), os, - [&](spirv::Capability cap) { os << spirv::stringifyCapability(cap); }); - printer << "], ["; - llvm::interleaveComma(triple.getExtensionsAttr(), os, [&](Attribute attr) { - os << llvm::cast(attr).getValue(); - }); - printer << "]>"; + << spirv::stringifyVersion(triple.getVersion()) << ", " + << llvm::interleaved_array(llvm::map_range( + triple.getCapabilities(), spirv::stringifyCapability)) + << ", " + << llvm::interleaved_array( + triple.getExtensionsAttr().getAsValueRange()) + << ">"; } static void print(spirv::TargetEnvAttr targetEnv, DialectAsmPrinter &printer) { diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 16e91b0cb2cfc..097a18f8a70eb 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -35,6 +35,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/InterleavedRange.h" #include #include #include @@ -807,10 +808,8 @@ void spirv::EntryPointOp::print(OpAsmPrinter &printer) { printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" "; printer.printSymbolName(getFn()); auto interfaceVars = getInterface().getValue(); - if (!interfaceVars.empty()) { - printer << ", "; - llvm::interleaveComma(interfaceVars, printer); - } + if (!interfaceVars.empty()) + printer << ", " << llvm::interleaved(interfaceVars); } LogicalResult spirv::EntryPointOp::verify() { @@ -862,13 +861,9 @@ void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) { printer << " "; printer.printSymbolName(getFn()); printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\""; - auto values = this->getValues(); - if (values.empty()) - return; - printer << ", "; - llvm::interleaveComma(values, printer, [&](Attribute a) { - printer << llvm::cast(a).getInt(); - }); + ArrayAttr values = this->getValues(); + if (!values.empty()) + printer << ", " << llvm::interleaved(values.getAsValueRange()); } //===----------------------------------------------------------------------===// @@ -1824,13 +1819,8 @@ ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser, void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) { printer << " "; printer.printSymbolName(getSymName()); - printer << " ("; - auto constituents = this->getConstituents().getValue(); - - if (!constituents.empty()) - llvm::interleaveComma(constituents, printer); - - printer << ") : " << getType(); + printer << " (" << llvm::interleaved(this->getConstituents().getValue()) + << ") : " << getType(); } LogicalResult spirv::SpecConstantCompositeOp::verify() {