Skip to content

Commit a226b5b

Browse files
Handled UnrealizedConversionCast for C code generation and validated tests
Signed-off-by: LekkalaSravya3 <[email protected]>
1 parent be2c6c1 commit a226b5b

File tree

2 files changed

+98
-1
lines changed

2 files changed

+98
-1
lines changed

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,64 @@ static LogicalResult printOperation(CppEmitter &emitter,
830830
return success();
831831
}
832832

833+
static LogicalResult printOperation(CppEmitter &emitter,
834+
mlir::UnrealizedConversionCastOp castOp) {
835+
raw_ostream &os = emitter.ostream();
836+
Operation &op = *castOp.getOperation();
837+
838+
if (castOp.getResults().size() != 1 || castOp.getOperands().size() != 1) {
839+
return castOp.emitOpError(
840+
"expected single result and single operand for conversion cast");
841+
}
842+
843+
Type destType = castOp.getResult(0).getType();
844+
845+
auto srcPtrType =
846+
mlir::dyn_cast<emitc::PointerType>(castOp.getOperand(0).getType());
847+
auto destArrayType = mlir::dyn_cast<emitc::ArrayType>(destType);
848+
849+
if (srcPtrType && destArrayType) {
850+
851+
// Emit declaration: (*v13)[dims] =
852+
if (failed(emitter.emitType(op.getLoc(), destArrayType.getElementType())))
853+
return failure();
854+
os << " (*" << emitter.getOrCreateName(op.getResult(0)) << ")";
855+
for (int64_t dim : destArrayType.getShape())
856+
os << "[" << dim << "]";
857+
os << " = ";
858+
859+
os << "(";
860+
861+
// Emit the C++ type for "datatype (*)[dim1][dim2]..."
862+
if (failed(emitter.emitType(op.getLoc(), destArrayType.getElementType())))
863+
return failure();
864+
865+
os << "(*)"; // Pointer to array
866+
867+
for (int64_t dim : destArrayType.getShape()) {
868+
os << "[" << dim << "]";
869+
}
870+
os << ")";
871+
if (failed(emitter.emitOperand(castOp.getOperand(0))))
872+
return failure();
873+
874+
return success();
875+
}
876+
877+
// Fallback to generic C-style cast for other cases
878+
if (failed(emitter.emitAssignPrefix(op)))
879+
return failure();
880+
881+
os << "(";
882+
if (failed(emitter.emitType(op.getLoc(), destType)))
883+
return failure();
884+
os << ")";
885+
if (failed(emitter.emitOperand(castOp.getOperand(0))))
886+
return failure();
887+
888+
return success();
889+
}
890+
833891
static LogicalResult printOperation(CppEmitter &emitter,
834892
emitc::ApplyOp applyOp) {
835893
raw_ostream &os = emitter.ostream();
@@ -1339,7 +1397,29 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
13391397
std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
13401398
std::string out;
13411399
llvm::raw_string_ostream ss(out);
1342-
ss << getOrCreateName(op.getValue());
1400+
Value baseValue = op.getValue();
1401+
1402+
// Check if the baseValue (%arg1) is a result of UnrealizedConversionCastOp
1403+
// that converts a pointer to an array type.
1404+
if (auto castOp = dyn_cast_or_null<mlir::UnrealizedConversionCastOp>(
1405+
baseValue.getDefiningOp())) {
1406+
auto destArrayType =
1407+
mlir::dyn_cast<emitc::ArrayType>(castOp.getResult(0).getType());
1408+
auto srcPtrType =
1409+
mlir::dyn_cast<emitc::PointerType>(castOp.getOperand(0).getType());
1410+
1411+
// If it's a pointer being cast to an array, emit (*varName)
1412+
if (srcPtrType && destArrayType) {
1413+
ss << "(*" << getOrCreateName(baseValue) << ")";
1414+
} else {
1415+
// Fallback if the cast is not our specific pointer-to-array case
1416+
ss << getOrCreateName(baseValue);
1417+
}
1418+
} else {
1419+
// Default behavior for a regular array or other base types
1420+
ss << getOrCreateName(baseValue);
1421+
}
1422+
13431423
for (auto index : op.getIndices()) {
13441424
ss << "[" << getOrCreateName(index) << "]";
13451425
}
@@ -1796,6 +1876,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
17961876
cacheDeferredOpResult(op.getResult(), getSubscriptName(op));
17971877
return success();
17981878
})
1879+
.Case<mlir::UnrealizedConversionCastOp>(
1880+
[&](auto op) { return printOperation(*this, op); })
17991881
.Default([&](Operation *) {
18001882
return op.emitOpError("unable to find printer for op");
18011883
});
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
2+
3+
// CHECK-LABEL: void builtin_cast
4+
func.func @builtin_cast(%arg0: !emitc.ptr<f32>){
5+
// CHECK : float (*v2)[1][3][4][4] = (float(*)[1][3][4][4])v1
6+
%1 = builtin.unrealized_conversion_cast %arg0 : !emitc.ptr<f32> to !emitc.array<1x3x4x4xf32>
7+
return
8+
}
9+
10+
// CHECK-LABEL: void builtin_cast_index
11+
func.func @builtin_cast_index(%arg0: !emitc.size_t){
12+
// CHECK : size_t v2 = (size_t)v1
13+
%1 = builtin.unrealized_conversion_cast %arg0 : !emitc.size_t to index
14+
return
15+
}

0 commit comments

Comments
 (0)