Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -795,12 +795,32 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// patterns even if a failure is encountered during the rewrite step.
bool canRecoverFromRewriteFailure() const override { return true; }

/// PatternRewriter hook for replacing an operation.
/// Replace the given operation with the new values. The number of op results
/// and replacement values must match. The types may differ: the dialect
/// conversion driver will reconcile any surviving type mismatches at the end
/// of the conversion process with source materializations. The given
/// operation is erased.
void replaceOp(Operation *op, ValueRange newValues) override;

/// PatternRewriter hook for replacing an operation.
/// Replace the given operation with the results of the new op. The number of
/// op results must match. The types may differ: the dialect conversion
/// driver will reconcile any surviving type mismatches at the end of the
/// conversion process with source materializations. The original operation
/// is erased.
void replaceOp(Operation *op, Operation *newOp) override;

/// Replace the given operation with the new value ranges. The number of op
/// results and value ranges must match. If an original SSA value is replaced
/// by multiple SSA values (i.e., a value range has more than 1 element), the
/// conversion driver will insert an argument materialization to convert the
/// N SSA values back into 1 SSA value of the original type. The given
/// operation is erased.
///
/// Note: The argument materialization is a workaround until we have full 1:N
/// support in the dialect conversion. (It is going to disappear from both
/// `replaceOpWithMultiple` and `applySignatureConversion`.)
void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);

/// PatternRewriter hook for erasing a dead operation. The uses of this
/// operation *must* be made dead by the end of the conversion process,
/// otherwise an assert will be issued.
Expand Down
40 changes: 12 additions & 28 deletions mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,47 +141,31 @@ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
getTypeConverter()));
}

// Create the new result types for the new `CallOp` and track the indices in
// the new call op's results that correspond to the old call op's results.
//
// expandedResultIndices[i] = "list of new result indices that old result i
// expanded to".
// Create the new result types for the new `CallOp` and track the number of
// replacement types for each original op result.
SmallVector<Type, 2> newResultTypes;
SmallVector<SmallVector<unsigned, 2>, 4> expandedResultIndices;
SmallVector<unsigned> expandedResultSizes;
for (Type resultType : op.getResultTypes()) {
unsigned oldSize = newResultTypes.size();
if (failed(typeConverter->convertType(resultType, newResultTypes)))
return failure();
auto &resultMapping = expandedResultIndices.emplace_back();
for (unsigned i = oldSize, e = newResultTypes.size(); i < e; i++)
resultMapping.push_back(i);
expandedResultSizes.push_back(newResultTypes.size() - oldSize);
}

CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCalleeAttr(),
newResultTypes, newOperands);

// Build a replacement value for each result to replace its uses. If a
// result has multiple mapping values, it needs to be materialized as a
// single value.
SmallVector<Value, 2> replacedValues;
// Build a replacement value for each result to replace its uses.
SmallVector<ValueRange> replacedValues;
replacedValues.reserve(op.getNumResults());
unsigned startIdx = 0;
for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
auto decomposedValues = llvm::to_vector<6>(
llvm::map_range(expandedResultIndices[i],
[&](unsigned i) { return newCallOp.getResult(i); }));
if (decomposedValues.empty()) {
// No replacement is required.
replacedValues.push_back(nullptr);
} else if (decomposedValues.size() == 1) {
replacedValues.push_back(decomposedValues.front());
} else {
// Materialize a single Value to replace the original Value.
Value materialized = getTypeConverter()->materializeArgumentConversion(
rewriter, op.getLoc(), op.getType(i), decomposedValues);
replacedValues.push_back(materialized);
}
ValueRange repl =
newCallOp.getResults().slice(startIdx, expandedResultSizes[i]);
replacedValues.push_back(repl);
startIdx += expandedResultSizes[i];
}
rewriter.replaceOp(op, replacedValues);
rewriter.replaceOpWithMultiple(op, replacedValues);
return success();
}
};
Expand Down
43 changes: 20 additions & 23 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -600,8 +600,8 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
flattenOperands(adaptor.getOperands(), flattened);
auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
finalRetTy, flattened);
// (2) Create cast operation for sparse tensor returns.
SmallVector<Value> castedRet;
// (2) Gather sparse tensor returns.
SmallVector<SmallVector<Value>> packedResultVals;
// Tracks the offset of current return value (of the original call)
// relative to the new call (after sparse tensor flattening).
unsigned retOffset = 0;
Expand All @@ -618,21 +618,22 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
assert(!sparseFlat.empty());
if (sparseFlat.size() > 1) {
auto flatSize = sparseFlat.size();
ValueRange fields(iterator_range<ResultRange::iterator>(
newCall.result_begin() + retOffset,
newCall.result_begin() + retOffset + flatSize));
castedRet.push_back(genTuple(rewriter, loc, retType, fields));
packedResultVals.emplace_back();
llvm::append_range(packedResultVals.back(),
newCall.getResults().slice(retOffset, flatSize));
retOffset += flatSize;
} else {
// If this is a 1:1 conversion, no need for casting.
castedRet.push_back(newCall.getResult(retOffset));
packedResultVals.emplace_back();
packedResultVals.back().push_back(newCall.getResult(retOffset));
retOffset++;
}
sparseFlat.clear();
}

assert(castedRet.size() == op.getNumResults());
rewriter.replaceOp(op, castedRet);
assert(packedResultVals.size() == op.getNumResults());
rewriter.replaceOpWithMultiple(
op, llvm::to_vector_of<ValueRange>(packedResultVals));
return success();
}
};
Expand Down Expand Up @@ -776,7 +777,7 @@ class SparseTensorAllocConverter
// Reuses specifier.
fields.push_back(desc.getSpecifier());
assert(fields.size() == desc.getNumFields());
rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
rewriter.replaceOpWithMultiple(op, {fields});
return success();
}

Expand All @@ -796,7 +797,7 @@ class SparseTensorAllocConverter
sizeHint, lvlSizesValues, fields);

// Replace operation with resulting memrefs.
rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
rewriter.replaceOpWithMultiple(op, {fields});
return success();
}

Expand Down Expand Up @@ -837,7 +838,7 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
sizeHint, lvlSizesValues, fields);

// Replace operation with resulting memrefs.
rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
rewriter.replaceOpWithMultiple(op, {fields});
return success();
}

Expand Down Expand Up @@ -893,7 +894,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
if (op.getHasInserts())
genEndInsert(rewriter, op.getLoc(), desc);
// Replace operation with resulting memrefs.
rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));
rewriter.replaceOpWithMultiple(op, {desc.getFields()});
return success();
}
};
Expand Down Expand Up @@ -1006,15 +1007,14 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
rewriter.create<scf::YieldOp>(loc, insertRet);

rewriter.setInsertionPointAfter(loop);
Value result = genTuple(rewriter, loc, dstType, loop->getResults());
// Deallocate the buffers on exit of the full loop nest.
Operation *parent = getTop(op);
rewriter.setInsertionPointAfter(parent);
rewriter.create<memref::DeallocOp>(loc, values);
rewriter.create<memref::DeallocOp>(loc, filled);
rewriter.create<memref::DeallocOp>(loc, added);
// Replace operation with resulting memrefs.
rewriter.replaceOp(op, result);
rewriter.replaceOpWithMultiple(op, {loop->getResults()});
return success();
}
};
Expand All @@ -1041,8 +1041,7 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
params, /*genCall=*/true);
SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
// Replace operation with resulting memrefs.
rewriter.replaceOp(op,
genTuple(rewriter, loc, op.getDest().getType(), ret));
rewriter.replaceOpWithMultiple(op, {ret});
return success();
}
};
Expand Down Expand Up @@ -1215,8 +1214,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
return true;
});

rewriter.replaceOp(
op, genTuple(rewriter, loc, op.getResult().getType(), fields));
rewriter.replaceOpWithMultiple(op, {fields});
return success();
}
};
Expand Down Expand Up @@ -1271,8 +1269,7 @@ class SparseExtractSliceConverter
// NOTE: we can not generate tuples directly from descriptor here, as the
// descriptor is holding the original type, yet we want the slice type
// here (they shared every memref but with an updated specifier).
rewriter.replaceOp(op, genTuple(rewriter, loc, op.getResult().getType(),
desc.getFields()));
rewriter.replaceOpWithMultiple(op, {desc.getFields()});
return success();
}
};
Expand Down Expand Up @@ -1403,7 +1400,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
}
desc.setValMemSize(rewriter, loc, memSize);

rewriter.replaceOp(op, genTuple(rewriter, loc, desc));
rewriter.replaceOpWithMultiple(op, {desc.getFields()});
return success();
}
};
Expand Down Expand Up @@ -1577,7 +1574,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
EmitCInterface::Off);

// Replace operation with resulting memrefs.
rewriter.replaceOp(op, genTuple(rewriter, loc, dstTp, fields));
rewriter.replaceOpWithMultiple(op, {fields});
return success();
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,24 @@ convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
// The sparse tensor type converter (defined in Passes.h).
//===----------------------------------------------------------------------===//

/// Materialize the buffers `inputs` back into a single sparse-tensor value of
/// type `tp`. Returns a null Value when `tp` carries no sparse encoding, i.e.,
/// no materialization is required for that type.
static Value materializeTuple(OpBuilder &builder, RankedTensorType tp,
                              ValueRange inputs, Location loc) {
  // Only sparse tensor types are packed into a tuple; the sparsifier knows
  // how to cancel out these casts.
  if (getSparseTensorEncoding(tp))
    return genTuple(builder, loc, tp, inputs);
  // Not a sparse tensor.
  return Value();
}

SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
addConversion([](Type type) { return type; });
addConversion(convertSparseTensorType);

// Required by scf.for 1:N type conversion.
addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp,
ValueRange inputs, Location loc) -> Value {
if (!getSparseTensorEncoding(tp))
// Not a sparse tensor.
return Value();
// Sparsifier knows how to cancel out these casts.
return genTuple(builder, loc, tp, inputs);
});
addSourceMaterialization(materializeTuple);

// Required as a workaround until we have full 1:N support.
addArgumentMaterialization(materializeTuple);
}

//===----------------------------------------------------------------------===//
Expand Down
Loading
Loading