@@ -782,6 +782,64 @@ static LogicalResult printOperation(CppEmitter &emitter,
782
782
return success ();
783
783
}
784
784
785
+ static LogicalResult printOperation (CppEmitter &emitter,
786
+ mlir::UnrealizedConversionCastOp castOp) {
787
+ raw_ostream &os = emitter.ostream ();
788
+ Operation &op = *castOp.getOperation ();
789
+
790
+ if (castOp.getResults ().size () != 1 || castOp.getOperands ().size () != 1 ) {
791
+ return castOp.emitOpError (
792
+ " expected single result and single operand for conversion cast" );
793
+ }
794
+
795
+ Type destType = castOp.getResult (0 ).getType ();
796
+
797
+ auto srcPtrType =
798
+ mlir::dyn_cast<emitc::PointerType>(castOp.getOperand (0 ).getType ());
799
+ auto destArrayType = mlir::dyn_cast<emitc::ArrayType>(destType);
800
+
801
+ if (srcPtrType && destArrayType) {
802
+
803
+ // Emit declaration: (*v13)[dims] =
804
+ if (failed (emitter.emitType (op.getLoc (), destArrayType.getElementType ())))
805
+ return failure ();
806
+ os << " (*" << emitter.getOrCreateName (op.getResult (0 )) << " )" ;
807
+ for (int64_t dim : destArrayType.getShape ())
808
+ os << " [" << dim << " ]" ;
809
+ os << " = " ;
810
+
811
+ os << " (" ;
812
+
813
+ // Emit the C++ type for "datatype (*)[dim1][dim2]..."
814
+ if (failed (emitter.emitType (op.getLoc (), destArrayType.getElementType ())))
815
+ return failure ();
816
+
817
+ os << " (*)" ; // Pointer to array
818
+
819
+ for (int64_t dim : destArrayType.getShape ()) {
820
+ os << " [" << dim << " ]" ;
821
+ }
822
+ os << " )" ;
823
+ if (failed (emitter.emitOperand (castOp.getOperand (0 ))))
824
+ return failure ();
825
+
826
+ return success ();
827
+ }
828
+
829
+ // Fallback to generic C-style cast for other cases
830
+ if (failed (emitter.emitAssignPrefix (op)))
831
+ return failure ();
832
+
833
+ os << " (" ;
834
+ if (failed (emitter.emitType (op.getLoc (), destType)))
835
+ return failure ();
836
+ os << " )" ;
837
+ if (failed (emitter.emitOperand (castOp.getOperand (0 ))))
838
+ return failure ();
839
+
840
+ return success ();
841
+ }
842
+
785
843
static LogicalResult printOperation (CppEmitter &emitter,
786
844
emitc::ApplyOp applyOp) {
787
845
raw_ostream &os = emitter.ostream ();
@@ -1291,7 +1349,29 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
1291
1349
std::string CppEmitter::getSubscriptName (emitc::SubscriptOp op) {
1292
1350
std::string out;
1293
1351
llvm::raw_string_ostream ss (out);
1294
- ss << getOrCreateName (op.getValue ());
1352
+ Value baseValue = op.getValue ();
1353
+
1354
+ // Check if the baseValue (%arg1) is a result of UnrealizedConversionCastOp
1355
+ // that converts a pointer to an array type.
1356
+ if (auto castOp = dyn_cast_or_null<mlir::UnrealizedConversionCastOp>(
1357
+ baseValue.getDefiningOp ())) {
1358
+ auto destArrayType =
1359
+ mlir::dyn_cast<emitc::ArrayType>(castOp.getResult (0 ).getType ());
1360
+ auto srcPtrType =
1361
+ mlir::dyn_cast<emitc::PointerType>(castOp.getOperand (0 ).getType ());
1362
+
1363
+ // If it's a pointer being cast to an array, emit (*varName)
1364
+ if (srcPtrType && destArrayType) {
1365
+ ss << " (*" << getOrCreateName (baseValue) << " )" ;
1366
+ } else {
1367
+ // Fallback if the cast is not our specific pointer-to-array case
1368
+ ss << getOrCreateName (baseValue);
1369
+ }
1370
+ } else {
1371
+ // Default behavior for a regular array or other base types
1372
+ ss << getOrCreateName (baseValue);
1373
+ }
1374
+
1295
1375
for (auto index : op.getIndices ()) {
1296
1376
ss << " [" << getOrCreateName (index) << " ]" ;
1297
1377
}
@@ -1747,6 +1827,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
1747
1827
cacheDeferredOpResult (op.getResult (), getSubscriptName (op));
1748
1828
return success ();
1749
1829
})
1830
+ .Case <mlir::UnrealizedConversionCastOp>(
1831
+ [&](auto op) { return printOperation (*this , op); })
1750
1832
.Default ([&](Operation *) {
1751
1833
return op.emitOpError (" unable to find printer for op" );
1752
1834
});
0 commit comments