Skip to content

Commit a1a16c9

Browse files
committed
[mlir][emitc] Emit casting away constness when taking address of const global
- Modify the C++ emitter to detect when an AddressOf op traces back to a const global. If it does, emit a C-style cast (T*)(&...) to strip the const qualification - Adapt mlir/test/Target/Cpp/global.mlir to check for correct syntax of the generated code
1 parent 8dee997 commit a1a16c9

File tree

2 files changed

+174
-6
lines changed

2 files changed

+174
-6
lines changed

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,51 @@ static bool shouldBeInlined(ExpressionOp expressionOp) {
397397
return false;
398398
}
399399

400+
/// Helper function to check if a value traces back to a const global.
401+
/// Handles direct GetGlobalOp and GetGlobalOp through one or more SubscriptOps.
402+
/// Returns the GlobalOp if found and it has const_specifier, nullptr otherwise.
403+
static emitc::GlobalOp getConstGlobal(Value value, Operation *fromOp) {
404+
while (auto subscriptOp = value.getDefiningOp<emitc::SubscriptOp>()) {
405+
value = subscriptOp.getValue();
406+
}
407+
408+
auto getGlobalOp = value.getDefiningOp<emitc::GetGlobalOp>();
409+
if (!getGlobalOp)
410+
return nullptr;
411+
412+
// Find the nearest symbol table to check whether the global is const.
413+
Operation *symbolTableOp = fromOp;
414+
while (symbolTableOp && !symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
415+
symbolTableOp = symbolTableOp->getParentOp();
416+
}
417+
418+
if (!symbolTableOp)
419+
return nullptr;
420+
421+
SymbolTable symbolTable(symbolTableOp);
422+
auto globalOp = symbolTable.lookup<emitc::GlobalOp>(getGlobalOp.getName());
423+
424+
if (globalOp && globalOp.getConstSpecifier())
425+
return globalOp;
426+
427+
return nullptr;
428+
}
429+
430+
/// Emit address-of with a cast to strip const qualification.
431+
/// Produces: (ResultType)(&operand)
432+
static LogicalResult emitAddressOfWithConstCast(CppEmitter &emitter,
433+
Operation &op, Value operand) {
434+
raw_ostream &os = emitter.ostream();
435+
os << "(";
436+
if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType())))
437+
return failure();
438+
os << ")(&";
439+
if (failed(emitter.emitOperand(operand)))
440+
return failure();
441+
os << ")";
442+
return success();
443+
}
444+
400445
static LogicalResult printOperation(CppEmitter &emitter,
401446
emitc::DereferenceOp dereferenceOp) {
402447
std::string out;
@@ -496,8 +541,15 @@ static LogicalResult printOperation(CppEmitter &emitter,
496541

497542
if (failed(emitter.emitAssignPrefix(op)))
498543
return failure();
544+
545+
Value operand = addressOfOp.getReference();
546+
547+
// Check if we're taking address of a const global.
548+
if (getConstGlobal(operand, &op))
549+
return emitAddressOfWithConstCast(emitter, op, operand);
550+
499551
os << "&";
500-
return emitter.emitOperand(addressOfOp.getReference());
552+
return emitter.emitOperand(operand);
501553
}
502554

503555
static LogicalResult printOperation(CppEmitter &emitter,
@@ -903,8 +955,16 @@ static LogicalResult printOperation(CppEmitter &emitter,
903955

904956
if (failed(emitter.emitAssignPrefix(op)))
905957
return failure();
906-
os << applyOp.getApplicableOperator();
907-
return emitter.emitOperand(applyOp.getOperand());
958+
959+
StringRef applicableOperator = applyOp.getApplicableOperator();
960+
Value operand = applyOp.getOperand();
961+
962+
// Check if we're taking address of a const global.
963+
if (applicableOperator == "&" && getConstGlobal(operand, &op))
964+
return emitAddressOfWithConstCast(emitter, op, operand);
965+
966+
os << applicableOperator;
967+
return emitter.emitOperand(operand);
908968
}
909969

910970
static LogicalResult printOperation(CppEmitter &emitter,

mlir/test/Target/Cpp/global.mlir

Lines changed: 111 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT
22
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP
3+
// check that the generated code is syntactically valid
4+
// RUN: mlir-translate -mlir-to-cpp %s | %host_cxx -fsyntax-only -x c++ -
5+
6+
emitc.include "stdint.h"
7+
emitc.include "limits.h"
8+
emitc.include "stddef.h"
39

410
emitc.global extern @decl : i8
511
// CPP-DEFAULT: extern int8_t decl;
@@ -21,6 +27,10 @@ emitc.global const @myconstant : !emitc.array<2xi16> = dense<2>
2127
// CPP-DEFAULT: const int16_t myconstant[2] = {2, 2};
2228
// CPP-DECLTOP: const int16_t myconstant[2] = {2, 2};
2329

30+
emitc.global const @myconstant_2d : !emitc.array<2x3xi16> = dense<1>
31+
// CPP-DEFAULT: const int16_t myconstant_2d[2][3] = {1, 1, 1, 1, 1, 1};
32+
// CPP-DECLTOP: const int16_t myconstant_2d[2][3] = {1, 1, 1, 1, 1, 1};
33+
2434
emitc.global extern const @extern_constant : !emitc.array<2xi16>
2535
// CPP-DEFAULT: extern const int16_t extern_constant[2];
2636
// CPP-DECLTOP: extern const int16_t extern_constant[2];
@@ -29,9 +39,9 @@ emitc.global static @static_var : f32
2939
// CPP-DEFAULT: static float static_var;
3040
// CPP-DECLTOP: static float static_var;
3141

32-
emitc.global static @static_const : f32 = 3.0
33-
// CPP-DEFAULT: static float static_const = 3.000000000e+00f;
34-
// CPP-DECLTOP: static float static_const = 3.000000000e+00f;
42+
emitc.global static const @static_const : f32 = 3.0
43+
// CPP-DEFAULT: static const float static_const = 3.000000000e+00f;
44+
// CPP-DECLTOP: static const float static_const = 3.000000000e+00f;
3545

3646
emitc.global @opaque_init : !emitc.opaque<"char"> = #emitc.opaque<"CHAR_MIN">
3747
// CPP-DEFAULT: char opaque_init = CHAR_MIN;
@@ -98,3 +108,101 @@ func.func @use_global_array_write(%i: index, %val : f32) {
98108
// CPP-DECLTOP-SAME: (size_t [[V1:.*]], float [[V2:.*]])
99109
// CPP-DECLTOP-NEXT: myglobal[[[V1]]] = [[V2]];
100110
// CPP-DECLTOP-NEXT: return;
111+
112+
func.func @use_const_global_array_pointer(%i: index) -> !emitc.ptr<i16> {
113+
%0 = emitc.get_global @myconstant : !emitc.array<2xi16>
114+
%1 = emitc.subscript %0[%i] : (!emitc.array<2xi16>, index) -> !emitc.lvalue<i16>
115+
%2 = emitc.apply "&"(%1) : (!emitc.lvalue<i16>) -> !emitc.ptr<i16>
116+
return %2 : !emitc.ptr<i16>
117+
}
118+
// CPP-DEFAULT-LABEL: int16_t* use_const_global_array_pointer
119+
// CPP-DEFAULT-SAME: (size_t [[V1:.*]])
120+
// CPP-DEFAULT-NEXT: int16_t* [[V2:.*]] = (int16_t*)(&myconstant[[[V1]]]);
121+
// CPP-DEFAULT-NEXT: return [[V2]];
122+
123+
// CPP-DECLTOP-LABEL: int16_t* use_const_global_array_pointer
124+
// CPP-DECLTOP-SAME: (size_t [[V1:.*]])
125+
// CPP-DECLTOP-NEXT: int16_t* [[V2:.*]];
126+
// CPP-DECLTOP-NEXT: [[V2]] = (int16_t*)(&myconstant[[[V1]]]);
127+
// CPP-DECLTOP-NEXT: return [[V2]];
128+
129+
func.func @use_const_global_scalar_pointer() -> !emitc.ptr<f32> {
130+
%0 = emitc.get_global @static_const : !emitc.lvalue<f32>
131+
%1 = emitc.apply "&"(%0) : (!emitc.lvalue<f32>) -> !emitc.ptr<f32>
132+
return %1 : !emitc.ptr<f32>
133+
}
134+
// CPP-DEFAULT-LABEL: float* use_const_global_scalar_pointer()
135+
// CPP-DEFAULT-NEXT: float* [[V1:.*]] = (float*)(&static_const);
136+
// CPP-DEFAULT-NEXT: return [[V1]];
137+
138+
// CPP-DECLTOP-LABEL: float* use_const_global_scalar_pointer()
139+
// CPP-DECLTOP-NEXT: float* [[V1:.*]];
140+
// CPP-DECLTOP-NEXT: [[V1]] = (float*)(&static_const);
141+
// CPP-DECLTOP-NEXT: return [[V1]];
142+
143+
func.func @use_const_global_2d_array_pointer(%i: index, %j: index) -> !emitc.ptr<i16> {
144+
%0 = emitc.get_global @myconstant_2d : !emitc.array<2x3xi16>
145+
%1 = emitc.subscript %0[%i, %j] : (!emitc.array<2x3xi16>, index, index) -> !emitc.lvalue<i16>
146+
%2 = emitc.apply "&"(%1) : (!emitc.lvalue<i16>) -> !emitc.ptr<i16>
147+
return %2 : !emitc.ptr<i16>
148+
}
149+
// CPP-DEFAULT-LABEL: int16_t* use_const_global_2d_array_pointer
150+
// CPP-DEFAULT-SAME: (size_t [[I:.*]], size_t [[J:.*]])
151+
// CPP-DEFAULT-NEXT: int16_t* [[V:.*]] = (int16_t*)(&myconstant_2d[[[I]]][[[J]]]);
152+
// CPP-DEFAULT-NEXT: return [[V]];
153+
154+
// CPP-DECLTOP-LABEL: int16_t* use_const_global_2d_array_pointer
155+
// CPP-DECLTOP-SAME: (size_t [[I:.*]], size_t [[J:.*]])
156+
// CPP-DECLTOP-NEXT: int16_t* [[V:.*]];
157+
// CPP-DECLTOP-NEXT: [[V]] = (int16_t*)(&myconstant_2d[[[I]]][[[J]]]);
158+
// CPP-DECLTOP-NEXT: return [[V]];
159+
160+
// Test emitc.address_of with const globals (same as emitc.apply "&" tests above)
161+
162+
func.func @use_const_global_array_pointer_address_of(%i: index) -> !emitc.ptr<i16> {
163+
%0 = emitc.get_global @myconstant : !emitc.array<2xi16>
164+
%1 = emitc.subscript %0[%i] : (!emitc.array<2xi16>, index) -> !emitc.lvalue<i16>
165+
%2 = emitc.address_of %1 : !emitc.lvalue<i16>
166+
return %2 : !emitc.ptr<i16>
167+
}
168+
// CPP-DEFAULT-LABEL: int16_t* use_const_global_array_pointer_address_of
169+
// CPP-DEFAULT-SAME: (size_t [[V1:.*]])
170+
// CPP-DEFAULT-NEXT: int16_t* [[V2:.*]] = (int16_t*)(&myconstant[[[V1]]]);
171+
// CPP-DEFAULT-NEXT: return [[V2]];
172+
173+
// CPP-DECLTOP-LABEL: int16_t* use_const_global_array_pointer_address_of
174+
// CPP-DECLTOP-SAME: (size_t [[V1:.*]])
175+
// CPP-DECLTOP-NEXT: int16_t* [[V2:.*]];
176+
// CPP-DECLTOP-NEXT: [[V2]] = (int16_t*)(&myconstant[[[V1]]]);
177+
// CPP-DECLTOP-NEXT: return [[V2]];
178+
179+
func.func @use_const_global_scalar_pointer_address_of() -> !emitc.ptr<f32> {
180+
%0 = emitc.get_global @static_const : !emitc.lvalue<f32>
181+
%1 = emitc.address_of %0 : !emitc.lvalue<f32>
182+
return %1 : !emitc.ptr<f32>
183+
}
184+
// CPP-DEFAULT-LABEL: float* use_const_global_scalar_pointer_address_of()
185+
// CPP-DEFAULT-NEXT: float* [[V1:.*]] = (float*)(&static_const);
186+
// CPP-DEFAULT-NEXT: return [[V1]];
187+
188+
// CPP-DECLTOP-LABEL: float* use_const_global_scalar_pointer_address_of()
189+
// CPP-DECLTOP-NEXT: float* [[V1:.*]];
190+
// CPP-DECLTOP-NEXT: [[V1]] = (float*)(&static_const);
191+
// CPP-DECLTOP-NEXT: return [[V1]];
192+
193+
func.func @use_const_global_2d_array_pointer_address_of(%i: index, %j: index) -> !emitc.ptr<i16> {
194+
%0 = emitc.get_global @myconstant_2d : !emitc.array<2x3xi16>
195+
%1 = emitc.subscript %0[%i, %j] : (!emitc.array<2x3xi16>, index, index) -> !emitc.lvalue<i16>
196+
%2 = emitc.address_of %1 : !emitc.lvalue<i16>
197+
return %2 : !emitc.ptr<i16>
198+
}
199+
// CPP-DEFAULT-LABEL: int16_t* use_const_global_2d_array_pointer_address_of
200+
// CPP-DEFAULT-SAME: (size_t [[I:.*]], size_t [[J:.*]])
201+
// CPP-DEFAULT-NEXT: int16_t* [[V:.*]] = (int16_t*)(&myconstant_2d[[[I]]][[[J]]]);
202+
// CPP-DEFAULT-NEXT: return [[V]];
203+
204+
// CPP-DECLTOP-LABEL: int16_t* use_const_global_2d_array_pointer_address_of
205+
// CPP-DECLTOP-SAME: (size_t [[I:.*]], size_t [[J:.*]])
206+
// CPP-DECLTOP-NEXT: int16_t* [[V:.*]];
207+
// CPP-DECLTOP-NEXT: [[V]] = (int16_t*)(&myconstant_2d[[[I]]][[[J]]]);
208+
// CPP-DECLTOP-NEXT: return [[V]];

0 commit comments

Comments
 (0)