Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 63 additions & 3 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,51 @@ static bool shouldBeInlined(ExpressionOp expressionOp) {
return false;
}

/// Helper function to check if a value traces back to a const global.
/// Handles direct GetGlobalOp and GetGlobalOp through one or more SubscriptOps.
/// Returns the GlobalOp if found and it has const_specifier, nullptr otherwise.
static emitc::GlobalOp getConstGlobal(Value value, Operation *fromOp) {
while (auto subscriptOp = value.getDefiningOp<emitc::SubscriptOp>()) {
value = subscriptOp.getValue();
}

auto getGlobalOp = value.getDefiningOp<emitc::GetGlobalOp>();
if (!getGlobalOp)
return nullptr;

// Find the nearest symbol table to check whether the global is const.
Operation *symbolTableOp = fromOp;
while (symbolTableOp && !symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
symbolTableOp = symbolTableOp->getParentOp();
}

if (!symbolTableOp)
return nullptr;

SymbolTable symbolTable(symbolTableOp);
auto globalOp = symbolTable.lookup<emitc::GlobalOp>(getGlobalOp.getName());

if (globalOp && globalOp.getConstSpecifier())
return globalOp;

return nullptr;
}

/// Emit address-of with a cast to strip const qualification.
/// Produces: (ResultType)(&operand)
static LogicalResult emitAddressOfWithConstCast(CppEmitter &emitter,
Operation &op, Value operand) {
raw_ostream &os = emitter.ostream();
os << "(";
if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType())))
return failure();
os << ")(&";
if (failed(emitter.emitOperand(operand)))
return failure();
os << ")";
return success();
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::DereferenceOp dereferenceOp) {
std::string out;
Expand Down Expand Up @@ -496,8 +541,15 @@ static LogicalResult printOperation(CppEmitter &emitter,

if (failed(emitter.emitAssignPrefix(op)))
return failure();

Value operand = addressOfOp.getReference();

// Check if we're taking address of a const global.
if (getConstGlobal(operand, &op))
return emitAddressOfWithConstCast(emitter, op, operand);

os << "&";
return emitter.emitOperand(addressOfOp.getReference());
return emitter.emitOperand(operand);
}

static LogicalResult printOperation(CppEmitter &emitter,
Expand Down Expand Up @@ -903,8 +955,16 @@ static LogicalResult printOperation(CppEmitter &emitter,

if (failed(emitter.emitAssignPrefix(op)))
return failure();
os << applyOp.getApplicableOperator();
return emitter.emitOperand(applyOp.getOperand());

StringRef applicableOperator = applyOp.getApplicableOperator();
Value operand = applyOp.getOperand();

// Check if we're taking address of a const global.
if (applicableOperator == "&" && getConstGlobal(operand, &op))
return emitAddressOfWithConstCast(emitter, op, operand);

os << applicableOperator;
return emitter.emitOperand(operand);
}

static LogicalResult printOperation(CppEmitter &emitter,
Expand Down
114 changes: 111 additions & 3 deletions mlir/test/Target/Cpp/global.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP
// check that the generated code is syntactically valid
// RUN: mlir-translate -mlir-to-cpp %s | %host_cxx -fsyntax-only -x c++ -

emitc.include "stdint.h"
emitc.include "limits.h"
emitc.include "stddef.h"

emitc.global extern @decl : i8
// CPP-DEFAULT: extern int8_t decl;
Expand All @@ -21,6 +27,10 @@ emitc.global const @myconstant : !emitc.array<2xi16> = dense<2>
// CPP-DEFAULT: const int16_t myconstant[2] = {2, 2};
// CPP-DECLTOP: const int16_t myconstant[2] = {2, 2};

emitc.global const @myconstant_2d : !emitc.array<2x3xi16> = dense<1>
// CPP-DEFAULT: const int16_t myconstant_2d[2][3] = {1, 1, 1, 1, 1, 1};
// CPP-DECLTOP: const int16_t myconstant_2d[2][3] = {1, 1, 1, 1, 1, 1};

emitc.global extern const @extern_constant : !emitc.array<2xi16>
// CPP-DEFAULT: extern const int16_t extern_constant[2];
// CPP-DECLTOP: extern const int16_t extern_constant[2];
Expand All @@ -29,9 +39,9 @@ emitc.global static @static_var : f32
// CPP-DEFAULT: static float static_var;
// CPP-DECLTOP: static float static_var;

emitc.global static @static_const : f32 = 3.0
// CPP-DEFAULT: static float static_const = 3.000000000e+00f;
// CPP-DECLTOP: static float static_const = 3.000000000e+00f;
emitc.global static const @static_const : f32 = 3.0
// CPP-DEFAULT: static const float static_const = 3.000000000e+00f;
// CPP-DECLTOP: static const float static_const = 3.000000000e+00f;

emitc.global @opaque_init : !emitc.opaque<"char"> = #emitc.opaque<"CHAR_MIN">
// CPP-DEFAULT: char opaque_init = CHAR_MIN;
Expand Down Expand Up @@ -98,3 +108,101 @@ func.func @use_global_array_write(%i: index, %val : f32) {
// CPP-DECLTOP-SAME: (size_t [[V1:.*]], float [[V2:.*]])
// CPP-DECLTOP-NEXT: myglobal[[[V1]]] = [[V2]];
// CPP-DECLTOP-NEXT: return;

func.func @use_const_global_array_pointer(%i: index) -> !emitc.ptr<i16> {
%0 = emitc.get_global @myconstant : !emitc.array<2xi16>
%1 = emitc.subscript %0[%i] : (!emitc.array<2xi16>, index) -> !emitc.lvalue<i16>
%2 = emitc.apply "&"(%1) : (!emitc.lvalue<i16>) -> !emitc.ptr<i16>
return %2 : !emitc.ptr<i16>
}
// CPP-DEFAULT-LABEL: int16_t* use_const_global_array_pointer
// CPP-DEFAULT-SAME: (size_t [[V1:.*]])
// CPP-DEFAULT-NEXT: int16_t* [[V2:.*]] = (int16_t*)(&myconstant[[[V1]]]);
// CPP-DEFAULT-NEXT: return [[V2]];

// CPP-DECLTOP-LABEL: int16_t* use_const_global_array_pointer
// CPP-DECLTOP-SAME: (size_t [[V1:.*]])
// CPP-DECLTOP-NEXT: int16_t* [[V2:.*]];
// CPP-DECLTOP-NEXT: [[V2]] = (int16_t*)(&myconstant[[[V1]]]);
// CPP-DECLTOP-NEXT: return [[V2]];

func.func @use_const_global_scalar_pointer() -> !emitc.ptr<f32> {
%0 = emitc.get_global @static_const : !emitc.lvalue<f32>
%1 = emitc.apply "&"(%0) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
return %1 : !emitc.ptr<f32>
}
// CPP-DEFAULT-LABEL: float* use_const_global_scalar_pointer()
// CPP-DEFAULT-NEXT: float* [[V1:.*]] = (float*)(&static_const);
// CPP-DEFAULT-NEXT: return [[V1]];

// CPP-DECLTOP-LABEL: float* use_const_global_scalar_pointer()
// CPP-DECLTOP-NEXT: float* [[V1:.*]];
// CPP-DECLTOP-NEXT: [[V1]] = (float*)(&static_const);
// CPP-DECLTOP-NEXT: return [[V1]];

func.func @use_const_global_2d_array_pointer(%i: index, %j: index) -> !emitc.ptr<i16> {
%0 = emitc.get_global @myconstant_2d : !emitc.array<2x3xi16>
%1 = emitc.subscript %0[%i, %j] : (!emitc.array<2x3xi16>, index, index) -> !emitc.lvalue<i16>
%2 = emitc.apply "&"(%1) : (!emitc.lvalue<i16>) -> !emitc.ptr<i16>
return %2 : !emitc.ptr<i16>
}
// CPP-DEFAULT-LABEL: int16_t* use_const_global_2d_array_pointer
// CPP-DEFAULT-SAME: (size_t [[I:.*]], size_t [[J:.*]])
// CPP-DEFAULT-NEXT: int16_t* [[V:.*]] = (int16_t*)(&myconstant_2d[[[I]]][[[J]]]);
// CPP-DEFAULT-NEXT: return [[V]];

// CPP-DECLTOP-LABEL: int16_t* use_const_global_2d_array_pointer
// CPP-DECLTOP-SAME: (size_t [[I:.*]], size_t [[J:.*]])
// CPP-DECLTOP-NEXT: int16_t* [[V:.*]];
// CPP-DECLTOP-NEXT: [[V]] = (int16_t*)(&myconstant_2d[[[I]]][[[J]]]);
// CPP-DECLTOP-NEXT: return [[V]];

// Test emitc.address_of with const globals (same as emitc.apply "&" tests above)

func.func @use_const_global_array_pointer_address_of(%i: index) -> !emitc.ptr<i16> {
%0 = emitc.get_global @myconstant : !emitc.array<2xi16>
%1 = emitc.subscript %0[%i] : (!emitc.array<2xi16>, index) -> !emitc.lvalue<i16>
%2 = emitc.address_of %1 : !emitc.lvalue<i16>
return %2 : !emitc.ptr<i16>
}
// CPP-DEFAULT-LABEL: int16_t* use_const_global_array_pointer_address_of
// CPP-DEFAULT-SAME: (size_t [[V1:.*]])
// CPP-DEFAULT-NEXT: int16_t* [[V2:.*]] = (int16_t*)(&myconstant[[[V1]]]);
// CPP-DEFAULT-NEXT: return [[V2]];

// CPP-DECLTOP-LABEL: int16_t* use_const_global_array_pointer_address_of
// CPP-DECLTOP-SAME: (size_t [[V1:.*]])
// CPP-DECLTOP-NEXT: int16_t* [[V2:.*]];
// CPP-DECLTOP-NEXT: [[V2]] = (int16_t*)(&myconstant[[[V1]]]);
// CPP-DECLTOP-NEXT: return [[V2]];

func.func @use_const_global_scalar_pointer_address_of() -> !emitc.ptr<f32> {
%0 = emitc.get_global @static_const : !emitc.lvalue<f32>
%1 = emitc.address_of %0 : !emitc.lvalue<f32>
return %1 : !emitc.ptr<f32>
}
// CPP-DEFAULT-LABEL: float* use_const_global_scalar_pointer_address_of()
// CPP-DEFAULT-NEXT: float* [[V1:.*]] = (float*)(&static_const);
// CPP-DEFAULT-NEXT: return [[V1]];

// CPP-DECLTOP-LABEL: float* use_const_global_scalar_pointer_address_of()
// CPP-DECLTOP-NEXT: float* [[V1:.*]];
// CPP-DECLTOP-NEXT: [[V1]] = (float*)(&static_const);
// CPP-DECLTOP-NEXT: return [[V1]];

func.func @use_const_global_2d_array_pointer_address_of(%i: index, %j: index) -> !emitc.ptr<i16> {
%0 = emitc.get_global @myconstant_2d : !emitc.array<2x3xi16>
%1 = emitc.subscript %0[%i, %j] : (!emitc.array<2x3xi16>, index, index) -> !emitc.lvalue<i16>
%2 = emitc.address_of %1 : !emitc.lvalue<i16>
return %2 : !emitc.ptr<i16>
}
// CPP-DEFAULT-LABEL: int16_t* use_const_global_2d_array_pointer_address_of
// CPP-DEFAULT-SAME: (size_t [[I:.*]], size_t [[J:.*]])
// CPP-DEFAULT-NEXT: int16_t* [[V:.*]] = (int16_t*)(&myconstant_2d[[[I]]][[[J]]]);
// CPP-DEFAULT-NEXT: return [[V]];

// CPP-DECLTOP-LABEL: int16_t* use_const_global_2d_array_pointer_address_of
// CPP-DECLTOP-SAME: (size_t [[I:.*]], size_t [[J:.*]])
// CPP-DECLTOP-NEXT: int16_t* [[V:.*]];
// CPP-DECLTOP-NEXT: [[V]] = (int16_t*)(&myconstant_2d[[[I]]][[[J]]]);
// CPP-DECLTOP-NEXT: return [[V]];
Loading