Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 98 additions & 69 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,18 @@ struct CppEmitter {
return lowestPrecedence();
return emittedExpressionPrecedence.back();
}

/// Emits attribute to the specified stream or returns failure.
LogicalResult emitAttributeToStream(Location loc, Attribute attr,
raw_ostream &ss);

/// Emits type 'type' to the specified stream or returns failure.
LogicalResult emitTypeToStream(Location loc, Type type, raw_ostream &ss);

/// Emits array of types as a std::tuple of the emitted types independently of
/// the array size to the specified stream.
LogicalResult emitTupleTypeToStream(Location loc, ArrayRef<Type> types,
raw_ostream &ss);
};
} // namespace

Expand Down Expand Up @@ -1319,10 +1331,28 @@ void CppEmitter::cacheDeferredOpResult(Value value, StringRef str) {
/// Return the existing or a new name for a Value.
StringRef CppEmitter::getOrCreateName(Value val) {
if (!valueMapper.count(val)) {
assert(!hasDeferredEmission(val.getDefiningOp()) &&
"cacheDeferredOpResult should have been called on this value, "
"update the emitOperation function.");
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
if (auto constant = dyn_cast_if_present<ConstantOp>(val.getDefiningOp());
constant && !shouldUseConstantsAsVariables()) {
std::string constantValueString;
llvm::raw_string_ostream ss(constantValueString);

ss << "(";
bool success =
succeeded(emitTypeToStream(val.getLoc(), constant.getType(), ss));
assert(success);
ss << ") ";

success = succeeded(
emitAttributeToStream(val.getLoc(), constant.getValue(), ss));
assert(success);

valueMapper.insert(val, constantValueString);
} else {
assert(!hasDeferredEmission(val.getDefiningOp()) &&
"cacheDeferredOpResult should have been called on this value, "
"update the emitOperation function.");
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
}
}
return *valueMapper.begin(val);
}
Expand Down Expand Up @@ -1353,16 +1383,21 @@ bool CppEmitter::hasBlockLabel(Block &block) {
}

LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
return CppEmitter::emitAttributeToStream(loc, attr, os);
}

LogicalResult CppEmitter::emitAttributeToStream(Location loc, Attribute attr,
raw_ostream &ss) {
auto printInt = [&](const APInt &val, bool isUnsigned) {
if (val.getBitWidth() == 1) {
if (val.getBoolValue())
os << "true";
ss << "true";
else
os << "false";
ss << "false";
} else {
SmallString<128> strValue;
val.toString(strValue, 10, !isUnsigned, false);
os << strValue;
ss << strValue;
}
};

Expand All @@ -1371,28 +1406,28 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
SmallString<128> strValue;
// Use default values of toString except don't truncate zeros.
val.toString(strValue, 0, 0, false);
os << strValue;
ss << strValue;
switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
case llvm::APFloatBase::S_IEEEhalf:
os << "f16";
ss << "f16";
break;
case llvm::APFloatBase::S_BFloat:
os << "bf16";
ss << "bf16";
break;
case llvm::APFloatBase::S_IEEEsingle:
os << "f";
ss << "f";
break;
case llvm::APFloatBase::S_IEEEdouble:
break;
default:
llvm_unreachable("unsupported floating point type");
};
} else if (val.isNaN()) {
os << "NAN";
ss << "NAN";
} else if (val.isInfinity()) {
if (val.isNegative())
os << "-";
os << "INFINITY";
ss << "-";
ss << "INFINITY";
}
};

Expand All @@ -1412,9 +1447,9 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
return emitError(
loc, "expected floating point attribute to be f16, bf16, f32 or f64");
}
os << '{';
interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
os << '}';
ss << '{';
interleaveComma(dense, ss, [&](const APFloat &val) { printFloat(val); });
ss << '}';
return success();
}

Expand All @@ -1432,34 +1467,34 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
if (auto iType = dyn_cast<IntegerType>(
cast<TensorType>(dense.getType()).getElementType())) {
os << '{';
interleaveComma(dense, os, [&](const APInt &val) {
ss << '{';
interleaveComma(dense, ss, [&](const APInt &val) {
printInt(val, shouldMapToUnsigned(iType.getSignedness()));
});
os << '}';
ss << '}';
return success();
}
if (auto iType = dyn_cast<IndexType>(
cast<TensorType>(dense.getType()).getElementType())) {
os << '{';
interleaveComma(dense, os,
ss << '{';
interleaveComma(dense, ss,
[&](const APInt &val) { printInt(val, false); });
os << '}';
ss << '}';
return success();
}
}

// Print opaque attributes.
if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(attr)) {
os << oAttr.getValue();
ss << oAttr.getValue();
return success();
}

// Print symbolic reference attributes.
if (auto sAttr = dyn_cast<SymbolRefAttr>(attr)) {
if (sAttr.getNestedReferences().size() > 1)
return emitError(loc, "attribute has more than 1 nested reference");
os << sAttr.getRootReference().getValue();
ss << sAttr.getRootReference().getValue();
return success();
}

Expand Down Expand Up @@ -1494,22 +1529,6 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) {

LogicalResult CppEmitter::emitOperand(Value value) {
Operation *def = value.getDefiningOp();
if (!shouldUseConstantsAsVariables()) {
if (auto constant = dyn_cast_if_present<ConstantOp>(def)) {
os << "((";

if (failed(emitType(constant.getLoc(), constant.getType()))) {
return failure();
}
os << ") ";

if (failed(emitAttribute(constant.getLoc(), constant.getValue()))) {
return failure();
}
os << ")";
return success();
}
}

if (isPartOfCurrentExpression(value)) {
assert(def && "Expected operand to be defined by an operation");
Expand Down Expand Up @@ -1809,18 +1828,23 @@ LogicalResult CppEmitter::emitReferenceToType(Location loc, Type type) {
}

LogicalResult CppEmitter::emitType(Location loc, Type type) {
return emitTypeToStream(loc, type, os);
}

LogicalResult CppEmitter::emitTypeToStream(Location loc, Type type,
raw_ostream &ss) {
if (auto iType = dyn_cast<IntegerType>(type)) {
switch (iType.getWidth()) {
case 1:
return (os << "bool"), success();
return (ss << "bool"), success();
case 8:
case 16:
case 32:
case 64:
if (shouldMapToUnsigned(iType.getSignedness()))
return (os << "uint" << iType.getWidth() << "_t"), success();
return (ss << "uint" << iType.getWidth() << "_t"), success();
else
return (os << "int" << iType.getWidth() << "_t"), success();
return (ss << "int" << iType.getWidth() << "_t"), success();
default:
return emitError(loc, "cannot emit integer type ") << type;
}
Expand All @@ -1829,48 +1853,48 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
switch (fType.getWidth()) {
case 16: {
if (llvm::isa<Float16Type>(type))
return (os << "_Float16"), success();
return (ss << "_Float16"), success();
else if (llvm::isa<BFloat16Type>(type))
return (os << "__bf16"), success();
return (ss << "__bf16"), success();
else
return emitError(loc, "cannot emit float type ") << type;
}
case 32:
return (os << "float"), success();
return (ss << "float"), success();
case 64:
return (os << "double"), success();
return (ss << "double"), success();
default:
return emitError(loc, "cannot emit float type ") << type;
}
}
if (auto iType = dyn_cast<IndexType>(type))
return (os << "size_t"), success();
return (ss << "size_t"), success();
if (auto sType = dyn_cast<emitc::SizeTType>(type))
return (os << "size_t"), success();
return (ss << "size_t"), success();
if (auto sType = dyn_cast<emitc::SignedSizeTType>(type))
return (os << "ssize_t"), success();
return (ss << "ssize_t"), success();
if (auto pType = dyn_cast<emitc::PtrDiffTType>(type))
return (os << "ptrdiff_t"), success();
return (ss << "ptrdiff_t"), success();
if (auto tType = dyn_cast<TensorType>(type)) {
if (!tType.hasRank())
return emitError(loc, "cannot emit unranked tensor type");
if (!tType.hasStaticShape())
return emitError(loc, "cannot emit tensor type with non static shape");
os << "Tensor<";
ss << "Tensor<";
if (isa<ArrayType>(tType.getElementType()))
return emitError(loc, "cannot emit tensor of array type ") << type;
if (failed(emitType(loc, tType.getElementType())))
if (failed(emitTypeToStream(loc, tType.getElementType(), ss)))
return failure();
auto shape = tType.getShape();
for (auto dimSize : shape) {
os << ", ";
os << dimSize;
ss << ", ";
ss << dimSize;
}
os << ">";
ss << ">";
return success();
}
if (auto tType = dyn_cast<TupleType>(type))
return emitTupleType(loc, tType.getTypes());
return emitTupleTypeToStream(loc, tType.getTypes(), ss);
if (auto oType = dyn_cast<emitc::OpaqueType>(type)) {
FailureOr<SmallVector<ReplacementItem>> items = oType.parseFormatString();
if (failed(items))
Expand All @@ -1879,34 +1903,34 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
auto fmtArg = oType.getFmtArgs().begin();
for (ReplacementItem &item : *items) {
if (auto *str = std::get_if<StringRef>(&item)) {
os << *str;
ss << *str;
} else {
if (failed(emitType(loc, *fmtArg++))) {
if (failed(emitTypeToStream(loc, *fmtArg++, ss))) {
return failure();
}
}
}

return success();

os << oType.getValue();
ss << oType.getValue();
return success();
}
if (auto aType = dyn_cast<emitc::ArrayType>(type)) {
if (failed(emitType(loc, aType.getElementType())))
if (failed(emitTypeToStream(loc, aType.getElementType(), ss)))
return failure();
for (auto dim : aType.getShape())
os << "[" << dim << "]";
ss << "[" << dim << "]";
return success();
}
if (auto lType = dyn_cast<emitc::LValueType>(type))
return emitType(loc, lType.getValueType());
return emitTypeToStream(loc, lType.getValueType(), ss);
if (auto pType = dyn_cast<emitc::PointerType>(type)) {
if (isa<ArrayType>(pType.getPointee()))
return emitError(loc, "cannot emit pointer to array type ") << type;
if (failed(emitType(loc, pType.getPointee())))
if (failed(emitTypeToStream(loc, pType.getPointee(), ss)))
return failure();
os << "*";
ss << "*";
return success();
}
return emitError(loc, "cannot emit type ") << type;
Expand All @@ -1925,14 +1949,19 @@ LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
}

LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
return emitTupleTypeToStream(loc, types, os);
}
LogicalResult CppEmitter::emitTupleTypeToStream(Location loc,
ArrayRef<Type> types,
raw_ostream &ss) {
if (llvm::any_of(types, llvm::IsaPred<ArrayType>)) {
return emitError(loc, "cannot emit tuple of array type");
}
os << "std::tuple<";
ss << "std::tuple<";
if (failed(interleaveCommaWithError(
types, os, [&](Type type) { return emitType(loc, type); })))
types, ss, [&](Type type) { return emitType(loc, type); })))
return failure();
os << ">";
ss << ">";
return success();
}

Expand Down
Loading
Loading