Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions mlir/include/mlir/Dialect/CommonFolders.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,7 @@ template <class AttrElementT,
function_ref<std::optional<ElementValueT>(ElementValueT)>>
Attribute constFoldUnaryOpConditional(ArrayRef<Attribute> operands,
CalculationT &&calculate) {
assert(operands.size() == 1 && "unary op takes one operands");
if (!operands[0])
if (!llvm::getSingleElement(operands))
return {};

static_assert(
Expand Down Expand Up @@ -268,8 +267,7 @@ template <
class CalculationT = function_ref<TargetElementValueT(ElementValueT, bool)>>
Attribute constFoldCastOp(ArrayRef<Attribute> operands, Type resType,
CalculationT &&calculate) {
assert(operands.size() == 1 && "Cast op takes one operand");
if (!operands[0])
if (!llvm::getSingleElement(operands))
return {};

static_assert(
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Analysis/SliceAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ static void getBackwardSliceImpl(Operation *op,
// into us. For now, just bail.
if (parentOp && backwardSlice->count(parentOp) == 0) {
assert(parentOp->getNumRegions() == 1 &&
parentOp->getRegion(0).getBlocks().size() == 1);
llvm::hasSingleElement(parentOp->getRegion(0).getBlocks()));
getBackwardSliceImpl(parentOp, backwardSlice, options);
}
} else {
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -834,8 +834,7 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
assert(adaptor.getOperands().size() == 1);
Type srcType = adaptor.getOperands().front().getType();
Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
Type dstType = this->getTypeConverter()->convertType(op.getType());
if (!dstType)
return getTypeConversionFailure(rewriter, op);
Expand Down
6 changes: 2 additions & 4 deletions mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@ struct WmmaConstantOpToSPIRVLowering final
LogicalResult
matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
assert(adaptor.getOperands().size() == 1);
Value cst = adaptor.getOperands().front();
Value cst = llvm::getSingleElement(adaptor.getOperands());
auto coopType = getTypeConverter()->convertType(op.getType());
if (!coopType)
return rewriter.notifyMatchFailure(op, "type conversion failed");
Expand Down Expand Up @@ -181,8 +180,7 @@ struct WmmaElementwiseOpToSPIRVScalarMulLowering final
"splat is not a composite construct");
}

assert(cc.getConstituents().size() == 1);
scalar = cc.getConstituents().front();
scalar = llvm::getSingleElement(cc.getConstituents());

auto coopType = getTypeConverter()->convertType(op.getType());
if (!coopType)
Expand Down
6 changes: 2 additions & 4 deletions mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,13 +419,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> {
SmallVector<Value> dynDims, dynDevice;
for (auto dim : adaptor.getDimsDynamic()) {
// type conversion should be 1:1 for ints
assert(dim.size() == 1);
dynDims.emplace_back(dim[0]);
dynDims.emplace_back(llvm::getSingleElement(dim));
}
// same for device
for (auto device : adaptor.getDeviceDynamic()) {
assert(device.size() == 1);
dynDevice.emplace_back(device[0]);
dynDevice.emplace_back(llvm::getSingleElement(device));
}

// To keep the code simple, convert dims/device to values when they are
Expand Down
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1236,8 +1236,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
}

applyOp->erase();
assert(foldResults.size() == 1 && "expected 1 folded result");
return foldResults.front();
return llvm::getSingleElement(foldResults);
}

OpFoldResult
Expand Down Expand Up @@ -1306,8 +1305,7 @@ static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
}

minMaxOp->erase();
assert(foldResults.size() == 1 && "expected 1 folded result");
return foldResults.front();
return llvm::getSingleElement(foldResults);
}

OpFoldResult
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1249,8 +1249,7 @@ struct GreedyFusion {
SmallVector<Operation *, 2> sibLoadOpInsts;
sibNode->getLoadOpsForMemref(memref, &sibLoadOpInsts);
// Currently findSiblingNodeToFuse searches for siblings with one load.
assert(sibLoadOpInsts.size() == 1);
Operation *sibLoadOpInst = sibLoadOpInsts[0];
Operation *sibLoadOpInst = llvm::getSingleElement(sibLoadOpInsts);

// Gather 'dstNode' load ops to 'memref'.
SmallVector<Operation *, 2> dstLoadOpInsts;
Expand Down
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1604,10 +1604,8 @@ SmallVector<AffineForOp, 8> mlir::affine::tile(ArrayRef<AffineForOp> forOps,
ArrayRef<uint64_t> sizes,
AffineForOp target) {
SmallVector<AffineForOp, 8> res;
for (auto loops : tile(forOps, sizes, ArrayRef<AffineForOp>(target))) {
assert(loops.size() == 1);
res.push_back(loops[0]);
}
for (auto loops : tile(forOps, sizes, ArrayRef<AffineForOp>(target)))
res.push_back(llvm::getSingleElement(loops));
return res;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,30 +44,27 @@ struct LinalgCopyOpInterface
linalg::CopyOp> {
OpOperand &getSourceOperand(Operation *op) const {
auto copyOp = cast<CopyOp>(op);
assert(copyOp.getInputs().size() == 1 && "expected single input");
return copyOp.getInputsMutable()[0];
return llvm::getSingleElement(copyOp.getInputsMutable());
}

bool
isEquivalentSubset(Operation *op, Value candidate,
function_ref<bool(Value, Value)> equivalenceFn) const {
auto copyOp = cast<CopyOp>(op);
assert(copyOp.getOutputs().size() == 1 && "expected single output");
return equivalenceFn(candidate, copyOp.getOutputs()[0]);
return equivalenceFn(candidate,
llvm::getSingleElement(copyOp.getOutputs()));
}

Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
Location loc) const {
auto copyOp = cast<CopyOp>(op);
assert(copyOp.getOutputs().size() == 1 && "expected single output");
return copyOp.getOutputs()[0];
return llvm::getSingleElement(copyOp.getOutputs());
}

SmallVector<Value>
getValuesNeededToBuildSubsetExtraction(Operation *op) const {
auto copyOp = cast<CopyOp>(op);
assert(copyOp.getOutputs().size() == 1 && "expected single output");
return {copyOp.getOutputs()[0]};
return {llvm::getSingleElement(copyOp.getOutputs())};
}
};
} // namespace
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
/// extending the lifetime of allocations.
static bool lastNonTerminatorInRegion(Operation *op) {
return op->getNextNode() == op->getBlock()->getTerminator() &&
op->getParentRegion()->getBlocks().size() == 1;
llvm::hasSingleElement(op->getParentRegion()->getBlocks());
}

/// Inline an AllocaScopeOp if either the direct parent is an allocation scope
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ class QuantizedTypeConverter : public TypeConverter {

static Value materializeConversion(OpBuilder &builder, Type type,
ValueRange inputs, Location loc) {
assert(inputs.size() == 1);
return builder.create<quant::StorageCastOp>(loc, type, inputs[0]);
return builder.create<quant::StorageCastOp>(loc, type,
llvm::getSingleElement(inputs));
}

public:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
static bool doesNotAliasExternalValue(Value value, Region *region,
ValueRange exceptions,
const OneShotAnalysisState &state) {
assert(region->getBlocks().size() == 1 &&
assert(llvm::hasSingleElement(region->getBlocks()) &&
"expected region with single block");
bool result = true;
state.applyOnAliases(value, [&](Value alias) {
Expand Down
15 changes: 5 additions & 10 deletions mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,6 @@ static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
return result;
}

/// Assert that the given value range contains a single value and return it.
static Value getSingleValue(ValueRange values) {
assert(values.size() == 1 && "expected single value");
return values.front();
}

// CRTP
// A base class that takes care of 1:N type conversion, which maps the converted
// op results (computed by the derived class) and materializes 1:N conversion.
Expand Down Expand Up @@ -119,9 +113,9 @@ class ConvertForOpTypes
// We can not do clone as the number of result types after conversion
// might be different.
ForOp newOp = rewriter.create<ForOp>(
op.getLoc(), getSingleValue(adaptor.getLowerBound()),
getSingleValue(adaptor.getUpperBound()),
getSingleValue(adaptor.getStep()),
op.getLoc(), llvm::getSingleElement(adaptor.getLowerBound()),
llvm::getSingleElement(adaptor.getUpperBound()),
llvm::getSingleElement(adaptor.getStep()),
flattenValues(adaptor.getInitArgs()));

// Reserve whatever attributes in the original op.
Expand Down Expand Up @@ -149,7 +143,8 @@ class ConvertIfOpTypes
TypeRange dstTypes) const {

IfOp newOp = rewriter.create<IfOp>(
op.getLoc(), dstTypes, getSingleValue(adaptor.getCondition()), true);
op.getLoc(), dstTypes, llvm::getSingleElement(adaptor.getCondition()),
true);
newOp->setAttrs(op->getAttrs());

// We do not need the empty blocks created by rewriter.
Expand Down
6 changes: 2 additions & 4 deletions mlir/lib/Dialect/SCF/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1310,10 +1310,8 @@ SmallVector<Loops, 8> mlir::tile(ArrayRef<scf::ForOp> forOps,
Loops mlir::tile(ArrayRef<scf::ForOp> forOps, ArrayRef<Value> sizes,
scf::ForOp target) {
SmallVector<scf::ForOp, 8> res;
for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target))) {
assert(loops.size() == 1);
res.push_back(loops[0]);
}
for (auto loops : tile(forOps, sizes, ArrayRef<scf::ForOp>(target)))
res.push_back(llvm::getSingleElement(loops));
return res;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ struct AssumingOpInterface
size_t resultNum = std::distance(op->getOpResults().begin(),
llvm::find(op->getOpResults(), value));
// TODO: Support multiple blocks.
assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) &&
"expected exactly 1 block");
auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
assumingOp.getDoRegion().front().getTerminator());
Expand All @@ -49,7 +49,7 @@ struct AssumingOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto assumingOp = cast<shape::AssumingOp>(op);
assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) &&
"only 1 block supported");
auto yieldOp = cast<shape::AssumingYieldOp>(
assumingOp.getDoRegion().front().getTerminator());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,6 @@
using namespace mlir;
using namespace mlir::sparse_tensor;

/// Assert that the given value range contains a single value and return it.
static Value getSingleValue(ValueRange values) {
assert(values.size() == 1 && "expected single value");
return values.front();
}

static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
SmallVectorImpl<Type> &fields) {
// Position and coordinate buffer in the sparse structure.
Expand Down Expand Up @@ -200,7 +194,7 @@ class ExtractIterSpaceConverter

// Construct the iteration space.
SparseIterationSpace space(loc, rewriter,
getSingleValue(adaptor.getTensor()), 0,
llvm::getSingleElement(adaptor.getTensor()), 0,
op.getLvlRange(), adaptor.getParentIter());

SmallVector<Value> result = space.toValues();
Expand All @@ -218,8 +212,8 @@ class ExtractValOpConverter : public OpConversionPattern<ExtractValOp> {
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value pos = adaptor.getIterator().back();
Value valBuf =
rewriter.create<ToValuesOp>(loc, getSingleValue(adaptor.getTensor()));
Value valBuf = rewriter.create<ToValuesOp>(
loc, llvm::getSingleElement(adaptor.getTensor()));
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,6 @@ static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
return result;
}

/// Assert that the given value range contains a single value and return it.
static Value getSingleValue(ValueRange values) {
assert(values.size() == 1 && "expected single value");
return values.front();
}

/// Generates a load with proper `index` typing.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx) {
idx = genCast(builder, loc, idx, builder.getIndexType());
Expand Down Expand Up @@ -962,10 +956,10 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
SmallVector<Value> fields;
auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields,
op.getTensor().getType());
Value values = getSingleValue(adaptor.getValues());
Value filled = getSingleValue(adaptor.getFilled());
Value added = getSingleValue(adaptor.getAdded());
Value count = getSingleValue(adaptor.getCount());
Value values = llvm::getSingleElement(adaptor.getValues());
Value filled = llvm::getSingleElement(adaptor.getFilled());
Value added = llvm::getSingleElement(adaptor.getAdded());
Value count = llvm::getSingleElement(adaptor.getCount());
const SparseTensorType dstType(desc.getRankedTensorType());
Type eltType = dstType.getElementType();

Expand Down Expand Up @@ -1041,7 +1035,7 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
SmallVector<Value> params = llvm::to_vector(desc.getFields());
SmallVector<Value> flatIndices = flattenValues(adaptor.getIndices());
params.append(flatIndices.begin(), flatIndices.end());
params.push_back(getSingleValue(adaptor.getScalar()));
params.push_back(llvm::getSingleElement(adaptor.getScalar()));
SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
params, /*genCall=*/true);
SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -521,9 +521,8 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
Value ptr = genSubscript(env, builder, t, args);
if (llvm::isa<TensorType>(ptr.getType())) {
assert(env.options().sparseEmitStrategy ==
SparseEmitStrategy::kSparseIterator &&
args.size() == 1);
return builder.create<ExtractValOp>(loc, ptr, args.front());
SparseEmitStrategy::kSparseIterator);
return builder.create<ExtractValOp>(loc, ptr, llvm::getSingleElement(args));
}
return builder.create<memref::LoadOp>(loc, ptr, args);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1106,9 +1106,7 @@ Value FilterIterator::genShouldFilter(OpBuilder &b, Location l) {
Value notLegit = genCrdNotLegitPredicate(b, l, wrapCrd);
return {notLegit};
});

assert(r.size() == 1);
return r.front();
return llvm::getSingleElement(r);
}

Value FilterIterator::genNotEndImpl(OpBuilder &b, Location l) {
Expand All @@ -1120,8 +1118,7 @@ Value FilterIterator::genNotEndImpl(OpBuilder &b, Location l) {
// crd < size
return {CMPI(ult, crd, size)};
});
assert(r.size() == 1);
return r.front();
return llvm::getSingleElement(r);
}

ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
Expand All @@ -1145,7 +1142,6 @@ ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
/*beforeBuilder=*/
[this](OpBuilder &b, Location l, ValueRange ivs) {
ValueRange isFirst = linkNewScope(ivs);
assert(isFirst.size() == 1);
scf::ValueVector cont =
genWhenInBound(b, l, *wrap, C_FALSE,
[this, isFirst](OpBuilder &b, Location l,
Expand All @@ -1155,7 +1151,7 @@ ValueRange FilterIterator::forwardImpl(OpBuilder &b, Location l) {
genCrdNotLegitPredicate(b, l, wrapCrd);
Value crd = fromWrapCrd(b, l, wrapCrd);
Value ret = ANDI(CMPI(ult, crd, size), notLegit);
ret = ORI(ret, isFirst.front());
ret = ORI(ret, llvm::getSingleElement(isFirst));
return {ret};
});
b.create<scf::ConditionOp>(l, cont.front(), ivs);
Expand Down Expand Up @@ -1200,8 +1196,7 @@ Value SubSectIterHelper::genNotEnd(OpBuilder &b, Location l) {
// crd < size
return {CMPI(ult, crd, subSect.subSectSz)};
});
assert(r.size() == 1);
return r.front();
return llvm::getSingleElement(r);
}

Value SubSectIterHelper::deref(OpBuilder &b, Location l) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -833,8 +833,7 @@ makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
moduleTranslation, &phis)))
return llvm::createStringError(
"failed to inline `combiner` region of `omp.declare_reduction`");
assert(phis.size() == 1);
result = phis[0];
result = llvm::getSingleElement(phis);
return builder.saveIP();
};
return gen;
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,7 @@ Value CodeGen::genSingleExpr(const ast::Expr *expr) {
.Case<const ast::CallExpr, const ast::DeclRefExpr, const ast::TupleExpr>(
[&](auto derivedNode) {
SmallVector<Value> results = this->genExprImpl(derivedNode);
assert(results.size() == 1 && "expected single expression result");
return results[0];
return llvm::getSingleElement(results);
});
}

Expand Down
Loading
Loading