@@ -61,6 +61,9 @@ void mlir::emitc::buildTerminatedBody(OpBuilder &builder, Location loc) {
6161bool mlir::emitc::isSupportedEmitCType (Type type) {
6262 if (llvm::isa<emitc::OpaqueType>(type))
6363 return true ;
64+ if (auto lType = llvm::dyn_cast<emitc::LValueType>(type))
65+ // lvalue types are only allowed in a few places.
66+ return false ;
6467 if (auto ptrType = llvm::dyn_cast<emitc::PointerType>(type))
6568 return isSupportedEmitCType (ptrType.getPointee ());
6669 if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(type)) {
@@ -140,6 +143,8 @@ static LogicalResult verifyInitializationAttribute(Operation *op,
140143 << " string attributes are not supported, use #emitc.opaque instead" ;
141144
142145 Type resultType = op->getResult (0 ).getType ();
146+ if (auto lType = dyn_cast<LValueType>(resultType))
147+ resultType = lType.getValue ();
143148 Type attrType = cast<TypedAttr>(value).getType ();
144149
145150 if (resultType != attrType)
@@ -188,9 +193,19 @@ LogicalResult ApplyOp::verify() {
188193 if (applicableOperatorStr != " &" && applicableOperatorStr != " *" )
189194 return emitOpError (" applicable operator is illegal" );
190195
191- Operation *op = getOperand ().getDefiningOp ();
192- if (op && dyn_cast<ConstantOp>(op))
193- return emitOpError (" cannot apply to constant" );
196+ Type operandType = getOperand ().getType ();
197+ Type resultType = getResult ().getType ();
198+ if (applicableOperatorStr == " &" ) {
199+ if (!llvm::isa<emitc::LValueType>(operandType))
200+ return emitOpError (" operand type must be an lvalue when applying `&`" );
201+ if (!llvm::isa<emitc::PointerType>(resultType))
202+ return emitOpError (" result type must be a pointer when applying `&`" );
203+ } else {
204+ if (!llvm::isa<emitc::PointerType>(operandType))
205+ return emitOpError (" operand type must be a pointer when applying `*`" );
206+ if (!llvm::isa<emitc::LValueType>(resultType))
207+ return emitOpError (" result type must be an lvalue when applying `*`" );
208+ }
194209
195210 return success ();
196211}
@@ -202,20 +217,18 @@ LogicalResult ApplyOp::verify() {
202217// / The assign op requires that the assigned value's type matches the
203218// / assigned-to variable type.
204219LogicalResult emitc::AssignOp::verify () {
205- Value variable = getVar ();
206- Operation *variableDef = variable.getDefiningOp ();
207- if (!variableDef ||
208- !llvm::isa<emitc::VariableOp, emitc::SubscriptOp>(variableDef))
209- return emitOpError () << " requires first operand (" << variable
210- << " ) to be a Variable or subscript" ;
211-
212- Value value = getValue ();
213- if (variable.getType () != value.getType ())
214- return emitOpError () << " requires value's type (" << value.getType ()
215- << " ) to match variable's type (" << variable.getType ()
216- << " )" ;
217- if (isa<ArrayType>(variable.getType ()))
218- return emitOpError () << " cannot assign to array type" ;
220+ TypedValue<emitc::LValueType> variable = getVar ();
221+
222+ if (!variable.getDefiningOp ())
223+ return emitOpError () << " cannot assign to block argument" ;
224+
225+ Type valueType = getValue ().getType ();
226+ Type variableType = variable.getType ().getValue ();
227+ if (variableType != valueType)
228+ return emitOpError () << " requires value's type (" << valueType
229+ << " ) to match variable's type (" << variableType
230+ << " )\n variable: " << variable
231+ << " \n value: " << getValue () << " \n " ;
219232 return success ();
220233}
221234
@@ -842,9 +855,10 @@ LogicalResult emitc::SubscriptOp::verify() {
842855 }
843856 // Check element type.
844857 Type elementType = arrayType.getElementType ();
845- if (elementType != getType ()) {
858+ Type resultType = getType ().getValue ();
859+ if (elementType != resultType) {
846860 return emitOpError () << " on array operand requires element type ("
847- << elementType << " ) and result type (" << getType ()
861+ << elementType << " ) and result type (" << resultType
848862 << " ) to match" ;
849863 }
850864 return success ();
@@ -868,9 +882,10 @@ LogicalResult emitc::SubscriptOp::verify() {
868882 }
869883 // Check pointee type.
870884 Type pointeeType = pointerType.getPointee ();
871- if (pointeeType != getType ()) {
885+ Type resultType = getType ().getValue ();
886+ if (pointeeType != resultType) {
872887 return emitOpError () << " on pointer operand requires pointee type ("
873- << pointeeType << " ) and result type (" << getType ()
888+ << pointeeType << " ) and result type (" << resultType
874889 << " ) to match" ;
875890 }
876891 return success ();
@@ -964,6 +979,25 @@ emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
964979 return emitc::ArrayType::get (*shape, elementType);
965980}
966981
982+ // ===----------------------------------------------------------------------===//
983+ // LValueType
984+ // ===----------------------------------------------------------------------===//
985+
986+ LogicalResult mlir::emitc::LValueType::verify (
987+ llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
988+ mlir::Type value) {
989+ // Check that the wrapped type is valid. This especially forbids nested lvalue
990+ // types.
991+ if (!isSupportedEmitCType (value))
992+ return emitError ()
993+ << " !emitc.lvalue must wrap supported emitc type, but got " << value;
994+
995+ if (llvm::isa<emitc::ArrayType>(value))
996+ return emitError () << " !emitc.lvalue cannot wrap !emitc.array type" ;
997+
998+ return success ();
999+ }
1000+
9671001// ===----------------------------------------------------------------------===//
9681002// OpaqueType
9691003// ===----------------------------------------------------------------------===//
@@ -981,6 +1015,18 @@ LogicalResult mlir::emitc::OpaqueType::verify(
9811015 return success ();
9821016}
9831017
1018+ // ===----------------------------------------------------------------------===//
1019+ // PointerType
1020+ // ===----------------------------------------------------------------------===//
1021+
1022+ LogicalResult mlir::emitc::PointerType::verify (
1023+ llvm::function_ref<mlir::InFlightDiagnostic()> emitError, Type value) {
1024+ if (llvm::isa<emitc::LValueType>(value))
1025+ return emitError () << " pointers to lvalues are not allowed" ;
1026+
1027+ return success ();
1028+ }
1029+
9841030// ===----------------------------------------------------------------------===//
9851031// GlobalOp
9861032// ===----------------------------------------------------------------------===//
@@ -1078,9 +1124,22 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
10781124 << getName () << " ' does not reference a valid emitc.global" ;
10791125
10801126 Type resultType = getResult ().getType ();
1081- if (global.getType () != resultType)
1082- return emitOpError (" result type " )
1083- << resultType << " does not match type " << global.getType ()
1127+ Type globalType = global.getType ();
1128+
1129+ // global has array type
1130+ if (llvm::isa<ArrayType>(globalType)) {
1131+ if (globalType != resultType)
1132+ return emitOpError (" on array type expects result type " )
1133+ << resultType << " to match type " << globalType
1134+ << " of the global @" << getName ();
1135+ return success ();
1136+ }
1137+
1138+ // global has non-array type
1139+ auto lvalueType = dyn_cast<LValueType>(resultType);
1140+ if (!lvalueType || lvalueType.getValue () != globalType)
1141+ return emitOpError (" on non-array type expects result inner type " )
1142+ << lvalueType.getValue () << " to match type " << globalType
10841143 << " of the global @" << getName ();
10851144 return success ();
10861145}
0 commit comments