Skip to content

Commit d5a9fe3

Browse files
Handled UnrealizedConversionCast for C code generation and validated tests
Signed-off-by: LekkalaSravya3 <[email protected]>
1 parent a7f5abb commit d5a9fe3

File tree

2 files changed

+98
-1
lines changed

2 files changed

+98
-1
lines changed

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,64 @@ static LogicalResult printOperation(CppEmitter &emitter,
782782
return success();
783783
}
784784

785+
static LogicalResult printOperation(CppEmitter &emitter,
786+
mlir::UnrealizedConversionCastOp castOp) {
787+
raw_ostream &os = emitter.ostream();
788+
Operation &op = *castOp.getOperation();
789+
790+
if (castOp.getResults().size() != 1 || castOp.getOperands().size() != 1) {
791+
return castOp.emitOpError(
792+
"expected single result and single operand for conversion cast");
793+
}
794+
795+
Type destType = castOp.getResult(0).getType();
796+
797+
auto srcPtrType =
798+
mlir::dyn_cast<emitc::PointerType>(castOp.getOperand(0).getType());
799+
auto destArrayType = mlir::dyn_cast<emitc::ArrayType>(destType);
800+
801+
if (srcPtrType && destArrayType) {
802+
803+
// Emit declaration: (*v13)[dims] =
804+
if (failed(emitter.emitType(op.getLoc(), destArrayType.getElementType())))
805+
return failure();
806+
os << " (*" << emitter.getOrCreateName(op.getResult(0)) << ")";
807+
for (int64_t dim : destArrayType.getShape())
808+
os << "[" << dim << "]";
809+
os << " = ";
810+
811+
os << "(";
812+
813+
// Emit the C++ type for "datatype (*)[dim1][dim2]..."
814+
if (failed(emitter.emitType(op.getLoc(), destArrayType.getElementType())))
815+
return failure();
816+
817+
os << "(*)"; // Pointer to array
818+
819+
for (int64_t dim : destArrayType.getShape()) {
820+
os << "[" << dim << "]";
821+
}
822+
os << ")";
823+
if (failed(emitter.emitOperand(castOp.getOperand(0))))
824+
return failure();
825+
826+
return success();
827+
}
828+
829+
// Fallback to generic C-style cast for other cases
830+
if (failed(emitter.emitAssignPrefix(op)))
831+
return failure();
832+
833+
os << "(";
834+
if (failed(emitter.emitType(op.getLoc(), destType)))
835+
return failure();
836+
os << ")";
837+
if (failed(emitter.emitOperand(castOp.getOperand(0))))
838+
return failure();
839+
840+
return success();
841+
}
842+
785843
static LogicalResult printOperation(CppEmitter &emitter,
786844
emitc::ApplyOp applyOp) {
787845
raw_ostream &os = emitter.ostream();
@@ -1291,7 +1349,29 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
12911349
std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
12921350
std::string out;
12931351
llvm::raw_string_ostream ss(out);
1294-
ss << getOrCreateName(op.getValue());
1352+
Value baseValue = op.getValue();
1353+
1354+
// Check if the baseValue (%arg1) is a result of UnrealizedConversionCastOp
1355+
// that converts a pointer to an array type.
1356+
if (auto castOp = dyn_cast_or_null<mlir::UnrealizedConversionCastOp>(
1357+
baseValue.getDefiningOp())) {
1358+
auto destArrayType =
1359+
mlir::dyn_cast<emitc::ArrayType>(castOp.getResult(0).getType());
1360+
auto srcPtrType =
1361+
mlir::dyn_cast<emitc::PointerType>(castOp.getOperand(0).getType());
1362+
1363+
// If it's a pointer being cast to an array, emit (*varName)
1364+
if (srcPtrType && destArrayType) {
1365+
ss << "(*" << getOrCreateName(baseValue) << ")";
1366+
} else {
1367+
// Fallback if the cast is not our specific pointer-to-array case
1368+
ss << getOrCreateName(baseValue);
1369+
}
1370+
} else {
1371+
// Default behavior for a regular array or other base types
1372+
ss << getOrCreateName(baseValue);
1373+
}
1374+
12951375
for (auto index : op.getIndices()) {
12961376
ss << "[" << getOrCreateName(index) << "]";
12971377
}
@@ -1747,6 +1827,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
17471827
cacheDeferredOpResult(op.getResult(), getSubscriptName(op));
17481828
return success();
17491829
})
1830+
.Case<mlir::UnrealizedConversionCastOp>(
1831+
[&](auto op) { return printOperation(*this, op); })
17501832
.Default([&](Operation *) {
17511833
return op.emitOpError("unable to find printer for op");
17521834
});
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
2+
3+
// CHECK-LABEL: void builtin_cast
4+
func.func @builtin_cast(%arg0: !emitc.ptr<f32>){
5+
// CHECK : float (*v2)[1][3][4][4] = (float(*)[1][3][4][4])v1
6+
%1 = builtin.unrealized_conversion_cast %arg0 : !emitc.ptr<f32> to !emitc.array<1x3x4x4xf32>
7+
return
8+
}
9+
10+
// CHECK-LABEL: void builtin_cast_index
11+
func.func @builtin_cast_index(%arg0: !emitc.size_t){
12+
// CHECK : size_t v2 = (size_t)v1
13+
%1 = builtin.unrealized_conversion_cast %arg0 : !emitc.size_t to index
14+
return
15+
}

0 commit comments

Comments
 (0)