Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 51 additions & 34 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,9 @@ struct CppEmitter {
return lowestPrecedence();
return emittedExpressionPrecedence.back();
}

LogicalResult emitAttributeToArbitraryStream(Location loc, Attribute attr,
raw_ostream &ss);
};
} // namespace

Expand Down Expand Up @@ -1290,7 +1293,15 @@ std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
llvm::raw_string_ostream ss(out);
ss << getOrCreateName(op.getValue());
for (auto index : op.getIndices()) {
ss << "[" << getOrCreateName(index) << "]";
ss << "[";
if (auto constant = dyn_cast_if_present<ConstantOp>(index.getDefiningOp());
constant && !shouldUseConstantsAsVariables()) {
assert(llvm::succeeded(emitAttributeToArbitraryStream(
op->getLoc(), constant.getValue(), ss)));
} else {
ss << getOrCreateName(index);
}
ss << "]";
}
return out;
}
Expand Down Expand Up @@ -1353,16 +1364,22 @@ bool CppEmitter::hasBlockLabel(Block &block) {
}

LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
return CppEmitter::emitAttributeToArbitraryStream(loc, attr, os);
}

LogicalResult CppEmitter::emitAttributeToArbitraryStream(Location loc,
Attribute attr,
raw_ostream &ss) {
auto printInt = [&](const APInt &val, bool isUnsigned) {
if (val.getBitWidth() == 1) {
if (val.getBoolValue())
os << "true";
ss << "true";
else
os << "false";
ss << "false";
} else {
SmallString<128> strValue;
val.toString(strValue, 10, !isUnsigned, false);
os << strValue;
ss << strValue;
}
};

Expand All @@ -1371,28 +1388,28 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
SmallString<128> strValue;
// Use default values of toString except don't truncate zeros.
val.toString(strValue, 0, 0, false);
os << strValue;
ss << strValue;
switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
case llvm::APFloatBase::S_IEEEhalf:
os << "f16";
ss << "f16";
break;
case llvm::APFloatBase::S_BFloat:
os << "bf16";
ss << "bf16";
break;
case llvm::APFloatBase::S_IEEEsingle:
os << "f";
ss << "f";
break;
case llvm::APFloatBase::S_IEEEdouble:
break;
default:
llvm_unreachable("unsupported floating point type");
};
} else if (val.isNaN()) {
os << "NAN";
ss << "NAN";
} else if (val.isInfinity()) {
if (val.isNegative())
os << "-";
os << "INFINITY";
ss << "-";
ss << "INFINITY";
}
};

Expand All @@ -1412,9 +1429,9 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
return emitError(
loc, "expected floating point attribute to be f16, bf16, f32 or f64");
}
os << '{';
interleaveComma(dense, os, [&](const APFloat &val) { printFloat(val); });
os << '}';
ss << '{';
interleaveComma(dense, ss, [&](const APFloat &val) { printFloat(val); });
ss << '}';
return success();
}

Expand All @@ -1432,34 +1449,34 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
if (auto iType = dyn_cast<IntegerType>(
cast<TensorType>(dense.getType()).getElementType())) {
os << '{';
interleaveComma(dense, os, [&](const APInt &val) {
ss << '{';
interleaveComma(dense, ss, [&](const APInt &val) {
printInt(val, shouldMapToUnsigned(iType.getSignedness()));
});
os << '}';
ss << '}';
return success();
}
if (auto iType = dyn_cast<IndexType>(
cast<TensorType>(dense.getType()).getElementType())) {
os << '{';
interleaveComma(dense, os,
ss << '{';
interleaveComma(dense, ss,
[&](const APInt &val) { printInt(val, false); });
os << '}';
ss << '}';
return success();
}
}

// Print opaque attributes.
if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(attr)) {
os << oAttr.getValue();
ss << oAttr.getValue();
return success();
}

// Print symbolic reference attributes.
if (auto sAttr = dyn_cast<SymbolRefAttr>(attr)) {
if (sAttr.getNestedReferences().size() > 1)
return emitError(loc, "attribute has more than 1 nested reference");
os << sAttr.getRootReference().getValue();
ss << sAttr.getRootReference().getValue();
return success();
}

Expand Down Expand Up @@ -1494,21 +1511,21 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) {

LogicalResult CppEmitter::emitOperand(Value value) {
Operation *def = value.getDefiningOp();
if (!shouldUseConstantsAsVariables()) {
if (auto constant = dyn_cast_if_present<ConstantOp>(def)) {
os << "((";

if (failed(emitType(constant.getLoc(), constant.getType()))) {
return failure();
}
os << ") ";
if (auto constant = dyn_cast_if_present<ConstantOp>(def);
constant && !shouldUseConstantsAsVariables()) {
os << "((";

if (failed(emitAttribute(constant.getLoc(), constant.getValue()))) {
return failure();
}
os << ")";
return success();
if (failed(emitType(constant.getLoc(), constant.getType()))) {
return failure();
}
os << ") ";

if (failed(emitAttribute(constant.getLoc(), constant.getValue()))) {
return failure();
}
os << ")";
return success();
}

if (isPartOfCurrentExpression(value)) {
Expand Down
12 changes: 11 additions & 1 deletion mlir/test/Target/Cpp/emitc-constants-as-variables.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,19 @@ func.func @test() {

return
}

// CPP-DEFAULT: void test() {
// CPP-DEFAULT-NEXT: for (size_t v1 = ((size_t) 0); v1 < ((size_t) 10); v1 += ((size_t) 1)) {
// CPP-DEFAULT-NEXT: }
// CPP-DEFAULT-NEXT: return;
// CPP-DEFAULT-NEXT: }

// -----

func.func @test_subscript(%arg0: !emitc.array<4xf32>) -> (!emitc.lvalue<f32>) {
%c0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
%0 = emitc.subscript %arg0[%c0] : (!emitc.array<4xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
return %0 : !emitc.lvalue<f32>
}
// CPP-DEFAULT: float test_subscript(float v1[4]) {
// CPP-DEFAULT-NEXT: return v1[0];
// CPP-DEFAULT-NEXT: }
Loading