[mlir][emitc] Fix creating pointer from constant array #162083
When creating a pointer from an emitc array, check whether the array comes from a constant global. If it does, create the pointer as !emitc.ptr<!emitc.opaque<"const {type}">>. Move the C type string creation logic out of TranslateToCpp.cpp into getCTypeString in EmitC.cpp as a shared utility function.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-emitc
Author: Hendrik_Klug (Jimmy2027)
Full diff: https://github.com/llvm/llvm-project/pull/162083.diff
5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
index eb7ddeb3bfc54..614895977588a 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
@@ -27,6 +27,7 @@
#include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc"
#include "mlir/Dialect/EmitC/IR/EmitCEnums.h.inc"
+#include <string>
#include <variant>
namespace mlir {
@@ -49,6 +50,10 @@ bool isSupportedFloatType(mlir::Type type);
/// Determines whether \p type is a emitc.size_t/ssize_t type.
bool isPointerWideType(mlir::Type type);
+/// Convert an MLIR type to its C type string representation.
+/// Returns an empty string if the type cannot be represented as a C type.
+std::string getCTypeString(Type type);
+
// Either a literal string, or an placeholder for the fmtArgs.
struct Placeholder {};
using ReplacementItem = std::variant<StringRef, Placeholder>;
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 2b7bdc9a7b7f8..7b05284818ecb 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -134,8 +135,26 @@ createPointerFromEmitcArray(Location loc, OpBuilder &builder,
llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex);
emitc::SubscriptOp subPtr =
emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices));
+
+ // Determine the pointer type
+ Type pointerElementType = arrayType.getElementType();
+
+ // Check if the array comes from a const global
+ if (auto getGlobalOp = arrayValue.getDefiningOp<emitc::GetGlobalOp>()) {
+ auto globalOp = SymbolTable::lookupNearestSymbolFrom<emitc::GlobalOp>(
+ getGlobalOp, getGlobalOp.getNameAttr());
+ if (globalOp && globalOp.getConstSpecifier()) {
+ // Create a const pointer type using opaque type
+ std::string cTypeString = emitc::getCTypeString(pointerElementType);
+ if (!cTypeString.empty()) {
+ pointerElementType = emitc::OpaqueType::get(builder.getContext(),
+ "const " + cTypeString);
+ }
+ }
+ }
+
emitc::ApplyOp ptr = emitc::ApplyOp::create(
- builder, loc, emitc::PointerType::get(arrayType.getElementType()),
+ builder, loc, emitc::PointerType::get(pointerElementType),
builder.getStringAttr("&"), subPtr);
return ptr;
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 5c8564bca6f86..d07993fb5a986 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -139,6 +139,43 @@ bool mlir::emitc::isFundamentalType(Type type) {
isa<emitc::PointerType>(type);
}
+std::string mlir::emitc::getCTypeString(Type type) {
+ if (auto intType = dyn_cast<IntegerType>(type)) {
+ switch (intType.getWidth()) {
+ case 1:
+ return "bool";
+ case 8:
+ return intType.isUnsigned() ? "uint8_t" : "int8_t";
+ case 16:
+ return intType.isUnsigned() ? "uint16_t" : "int16_t";
+ case 32:
+ return intType.isUnsigned() ? "uint32_t" : "int32_t";
+ case 64:
+ return intType.isUnsigned() ? "uint64_t" : "int64_t";
+ default:
+ return "";
+ }
+ }
+ if (auto floatType = dyn_cast<FloatType>(type)) {
+ if (floatType.getWidth() == 16) {
+ if (isa<Float16Type>(type))
+ return "_Float16";
+ if (isa<BFloat16Type>(type))
+ return "__bf16";
+ return "";
+ }
+ if (floatType.getWidth() == 32)
+ return "float";
+ if (floatType.getWidth() == 64)
+ return "double";
+ return "";
+ }
+ if (auto opaqueType = dyn_cast<emitc::OpaqueType>(type))
+ return opaqueType.getValue().str();
+
+ return "";
+}
+
/// Check that the type of the initial value is compatible with the operations
/// result type.
static LogicalResult verifyInitializationAttribute(Operation *op,
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index a5bd80e9d6b8b..16db8e1aaa12d 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1792,40 +1792,15 @@ LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type,
}
LogicalResult CppEmitter::emitType(Location loc, Type type) {
- if (auto iType = dyn_cast<IntegerType>(type)) {
- switch (iType.getWidth()) {
- case 1:
- return (os << "bool"), success();
- case 8:
- case 16:
- case 32:
- case 64:
- if (shouldMapToUnsigned(iType.getSignedness()))
- return (os << "uint" << iType.getWidth() << "_t"), success();
- else
- return (os << "int" << iType.getWidth() << "_t"), success();
- default:
- return emitError(loc, "cannot emit integer type ") << type;
- }
- }
- if (auto fType = dyn_cast<FloatType>(type)) {
- switch (fType.getWidth()) {
- case 16: {
- if (llvm::isa<Float16Type>(type))
- return (os << "_Float16"), success();
- if (llvm::isa<BFloat16Type>(type))
- return (os << "__bf16"), success();
- else
- return emitError(loc, "cannot emit float type ") << type;
- }
- case 32:
- return (os << "float"), success();
- case 64:
- return (os << "double"), success();
- default:
- return emitError(loc, "cannot emit float type ") << type;
- }
- }
+ std::string cTypeString = emitc::getCTypeString(type);
+ if (!cTypeString.empty())
+ return (os << cTypeString), success();
+
+ // Handle integer and float cases that failed above
+ if (isa<IntegerType>(type))
+ return emitError(loc, "cannot emit integer type ") << type;
+ if (isa<FloatType>(type))
+ return emitError(loc, "cannot emit float type ") << type;
if (auto iType = dyn_cast<IndexType>(type))
return (os << "size_t"), success();
if (auto sType = dyn_cast<emitc::SizeTType>(type))
@@ -1854,10 +1829,6 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
}
if (auto tType = dyn_cast<TupleType>(type))
return emitTupleType(loc, tType.getTypes());
- if (auto oType = dyn_cast<emitc::OpaqueType>(type)) {
- os << oType.getValue();
- return success();
- }
if (auto aType = dyn_cast<emitc::ArrayType>(type)) {
if (failed(emitType(loc, aType.getElementType())))
return failure();
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index 2b4eda37903d4..97c06639bf35b 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -58,3 +58,25 @@ module @globals {
return
}
}
+
+// -----
+
+// CHECK-LABEL: const_global_copy
+module @const_global_copy {
+ memref.global "private" constant @const_data : memref<4xi8> = dense<[1, 2, 3, 4]>
+ // CHECK: emitc.global static const @const_data : !emitc.array<4xi8> = dense<[1, 2, 3, 4]>
+
+ func.func @copy_from_const_global() {
+ // CHECK: get_global @const_data : !emitc.array<4xi8>
+ %0 = memref.get_global @const_data : memref<4xi8>
+ // CHECK: "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4xi8>
+ %1 = memref.alloca() : memref<4xi8>
+
+ // Verify that pointer from const global has const qualifier
+ // CHECK: apply "&"({{.*}}) : (!emitc.lvalue<i8>) -> !emitc.ptr<!emitc.opaque<"const int8_t">>
+ // CHECK: apply "&"({{.*}}) : (!emitc.lvalue<i8>) -> !emitc.ptr<i8>
+ // CHECK: call_opaque "memcpy"({{.*}}, {{.*}}, {{.*}}) : (!emitc.ptr<i8>, !emitc.ptr<!emitc.opaque<"const int8_t">>, !emitc.size_t) -> ()
+ memref.copy %0, %1 : memref<4xi8> to memref<4xi8>
+ return
+ }
+}
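The type-string mapping and the const-qualification step in the diff above can be sketched without any MLIR dependency. This is a minimal illustration only: `Kind`, `SketchType`, `sketchCTypeString`, and `sketchPointeeString` are hypothetical stand-ins invented here, not MLIR or EmitC APIs, and the real patch falls back to the original element type (not an empty string) when no C type string is available.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for the MLIR types handled by getCTypeString.
enum class Kind { Int, F16, BF16, F32, F64, Opaque, Other };

struct SketchType {
  Kind kind;
  unsigned width;      // bit width, only meaningful for Kind::Int
  bool isUnsigned;     // only meaningful for Kind::Int
  std::string opaque;  // only meaningful for Kind::Opaque
};

// Mirrors the mapping in the diff: returns "" when no C type exists.
std::string sketchCTypeString(const SketchType &t) {
  switch (t.kind) {
  case Kind::Int:
    if (t.width == 1)
      return "bool";
    if (t.width == 8 || t.width == 16 || t.width == 32 || t.width == 64)
      return std::string(t.isUnsigned ? "uint" : "int") +
             std::to_string(t.width) + "_t";
    return "";
  case Kind::F16:
    return "_Float16";
  case Kind::BF16:
    return "__bf16";
  case Kind::F32:
    return "float";
  case Kind::F64:
    return "double";
  case Kind::Opaque:
    return t.opaque;
  case Kind::Other:
    return "";
  }
  return "";
}

// Prefixes "const " only when the array comes from a const global and the
// element type has a C spelling; otherwise returns the plain spelling.
std::string sketchPointeeString(const SketchType &elem, bool fromConstGlobal) {
  std::string s = sketchCTypeString(elem);
  if (fromConstGlobal && !s.empty())
    return "const " + s;
  return s;
}
```

For an i8 element from a const global, the sketch yields the same pointee spelling that the new test checks for, "const int8_t".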
Hi @simon-camp, could you have a look at this?
Hi @Jimmy2027, sorry for the delay. I'm not sure about this change, as in MLIR the result of get_global is not constant, whereas the C version now is. So I'm not sure if we run into problems with users of the get_global. Is casting away the constness of the pointer a viable alternative? I think writing to this pointer will be UB in C, but that is also what's said in the documentation of memref.global.
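The alternative of casting the constness away can be sketched in plain C++ as follows. This is a hand-written illustration, not output of the EmitC lowering; `const_data` mirrors the global from the new test, and `copyFromConstGlobal` is a hypothetical name. Reading through the cast pointer is well-defined; writing through it would modify a const object, which is undefined behavior.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Const global, analogous to the lowered `memref.global constant @const_data`.
static const std::int8_t const_data[4] = {1, 2, 3, 4};

// Copies const_data into dst through a pointer whose constness was cast away.
void copyFromConstGlobal(std::int8_t *dst) {
  // The cast makes the pointer type match a non-const `int8_t *` user.
  std::int8_t *src = const_cast<std::int8_t *>(&const_data[0]);
  // Only reading through src: this is fine.
  std::memcpy(dst, src, sizeof(const_data));
  // src[0] = 0;  // would be undefined behavior: const_data is a const object
}
```

With this approach the pointer type stays the same for all users of the global, at the cost of the compiler no longer rejecting accidental writes.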
Thanks for reviewing this @simon-camp! I think I see what you mean, downstream ops might expect an
But you can't correctly use a pointer to const int, right? In MLIR you can pass the result of get_global from a constant global in a function call. In C you wouldn't be allowed to pass this to a function expecting a pointer to int or an int array. So I think you will get compile errors in the generated code, depending on the users of the get_global, with your changes right now. But maybe I'm wrong, I'd have to test this first.
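The conversion rule behind this concern can be checked at compile time. The sketch below is a standalone C++ illustration, not generated code; `constToMutableConverts`, `mutableToConstConverts`, and `readFirst` are hypothetical names introduced here.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Adding const is an implicit conversion: int8_t* -> const int8_t*.
constexpr bool mutableToConstConverts =
    std::is_convertible<std::int8_t *, const std::int8_t *>::value;

// Dropping const is not: const int8_t* -> int8_t* needs an explicit cast.
constexpr bool constToMutableConverts =
    std::is_convertible<const std::int8_t *, std::int8_t *>::value;

static_assert(mutableToConstConverts, "adding const is implicit");
static_assert(!constToMutableConverts, "dropping const requires a cast");

// A callee that only reads can accept both kinds of pointer.
std::int8_t readFirst(const std::int8_t *p) { return p[0]; }
```

So a user of get_global that is emitted as a call to a function taking a non-const pointer would fail to compile once the global's pointer carries a const qualifier, while a callee declared with a pointer to const accepts either.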
This doesn't seem like an emitc-specific issue (i.e. any dialect wishing to model that property in its type system would have a problem doing local type conversions), so this might be worth a wider discussion, e.g. are memref memory spaces the proper way to carry this information? Should memrefs support the constant property directly? At least in the short term I think the simplest solution is indeed to cast the constness away when lowering. In the long run, adding a
But wouldn't that be a user mistake in any case? C++ doesn't even compile a non-const pointer to a const object... I can't think of any meaningful example where this change breaks something that is not already broken, but maybe I'm biased and missing something :)
If I understand correctly, at the memref level constness is expected to be modeled with side effects and ops (memref.global constant). I'm not sure whether something like a const attr on the memref type would break its philosophy...
Yeah maybe this is the safest way forward for this PR, though I still can't think of any meaningful example where casting away the constness is actually needed...
yup I agree, that would be the nicest solution. The case I'm currently working with is actually obvious since the memref comes from a const global |