diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 34312655115a1..8cbdf710cfa6e 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -199,6 +199,41 @@ def OpenACC_DataClauseEnum : I64EnumAttr<"DataClause", def OpenACC_DataClauseAttr : EnumAttr; +// Data clause modifiers: +// * readonly: Added in OpenACC 2.7 to copyin and cache. +// * zero: Added in OpenACC 3.0 for create and copyout. +// * always, alwaysin, alwaysout: Added in OpenACC 3.4 for +// copy, copyin, and copyout clauses. +// * capture: Added in OpenACC 3.4 for copy, copyin, copyout and create clauses. +def OpenACC_DataClauseModifierNone : I32BitEnumAttrCaseNone<"none">; +// All of the modifiers below are bit flags - so the value noted is `1 << bit`. +// Thus the `zero` modifier is `1 << 0` = 1, `readonly` is `1 << 1` = 2, etc. +def OpenACC_DataClauseModifierZero : I32BitEnumAttrCaseBit<"zero", 0>; +def OpenACC_DataClauseModifierReadonly : I32BitEnumAttrCaseBit<"readonly", 1>; +def OpenACC_DataClauseModifierAlwaysIn : I32BitEnumAttrCaseBit<"alwaysin", 2>; +def OpenACC_DataClauseModifierAlwaysOut : I32BitEnumAttrCaseBit<"alwaysout", 3>; +def OpenACC_DataClauseModifierAlways : I32BitEnumAttrCaseGroup<"always", + [OpenACC_DataClauseModifierAlwaysIn, OpenACC_DataClauseModifierAlwaysOut]>; +def OpenACC_DataClauseModifierCapture : I32BitEnumAttrCaseBit<"capture", 4>; + +def OpenACC_DataClauseModifierEnum : I32BitEnumAttr< + "DataClauseModifier", + "Captures data clause modifiers", + [ + OpenACC_DataClauseModifierNone, OpenACC_DataClauseModifierZero, + OpenACC_DataClauseModifierReadonly, OpenACC_DataClauseModifierAlwaysIn, + OpenACC_DataClauseModifierAlwaysOut, OpenACC_DataClauseModifierAlways, + OpenACC_DataClauseModifierCapture]> { + let separator = ","; + let cppNamespace = "::mlir::acc"; + let genSpecializedAttr = 0; + let printBitEnumPrimaryGroups = 1; +} + +def OpenACC_DataClauseModifierAttr : EnumAttr; + class OpenACC_Attr traits = [], string baseCppClass = "::mlir::Attribute"> @@ -477,6 +512,8 @@ class OpenACC_DataEntryOp:$dataClause, DefaultValuedAttr:$structured, DefaultValuedAttr:$implicit, + DefaultValuedAttr:$modifiers, OptionalAttr:$name)); let description = !strconcat(extraDescription, [{ @@ -506,6 +543,7 @@ class OpenACC_DataEntryOp, OpBuilder<(ins "::mlir::Value":$var, "bool":$structured, "bool":$implicit, @@ -601,9 +640,23 @@ class OpenACC_DataEntryOp]; + }]>, + OpBuilder<(ins "::mlir::Type":$accVarType, "::mlir::Value":$var, + "::mlir::Type":$varType, "::mlir::Value":$varPtrPtr, + "::mlir::ValueRange":$bounds, + "::mlir::ValueRange":$asyncOperands, + "::mlir::ArrayAttr":$asyncOperandsDeviceType, + "::mlir::ArrayAttr":$asyncOnly, + "::mlir::acc::DataClause":$dataClause, "bool":$structured, + "bool":$implicit, "::mlir::StringAttr":$name), + [{ + build($_builder, $_state, accVarType, var, varType, varPtrPtr, bounds, + asyncOperands, asyncOperandsDeviceType, asyncOnly, dataClause, + structured, implicit, ::mlir::acc::DataClauseModifier::none, name); + }]>, + ]; } //===----------------------------------------------------------------------===// @@ -817,9 +870,7 @@ def OpenACC_CacheOp : OpenACC_DataEntryOp<"cache", let extraClassDeclaration = extraClassDeclarationBase # [{ /// Check if this is a cache with readonly modifier. - bool isCacheReadonly() { - return getDataClause() == acc::DataClause::acc_cache_readonly; - } + bool isCacheReadonly(); }]; } @@ -840,6 +891,8 @@ class OpenACC_DataExitOp:$dataClause, DefaultValuedAttr:$structured, DefaultValuedAttr:$implicit, + DefaultValuedAttr:$modifiers, OptionalAttr:$name)); let description = !strconcat(extraDescription, [{ @@ -861,6 +914,7 @@ class OpenACC_DataExitOp bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr, /*asyncOnly=*/nullptr, /*dataClause=*/nullptr, /*structured=*/$_builder.getBoolAttr(structured), - /*implicit=*/$_builder.getBoolAttr(implicit), /*name=*/nullptr); + /*implicit=*/$_builder.getBoolAttr(implicit), /*modifiers=*/nullptr, + /*name=*/nullptr); }]>, OpBuilder<(ins "::mlir::Value":$accVar, "::mlir::Value":$var, @@ -961,9 +1016,22 @@ class OpenACC_DataExitOpWithVarPtr bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr, /*asyncOnly=*/nullptr, /*dataClause=*/nullptr, /*structured=*/$_builder.getBoolAttr(structured), - /*implicit=*/$_builder.getBoolAttr(implicit), + /*implicit=*/$_builder.getBoolAttr(implicit), /*modifiers=*/nullptr, /*name=*/$_builder.getStringAttr(name)); - }]>]; + }]>, + OpBuilder<(ins "::mlir::Value":$accVar, "::mlir::Value":$var, + "::mlir::Type":$varType, "::mlir::ValueRange":$bounds, + "::mlir::ValueRange":$asyncOperands, + "::mlir::ArrayAttr":$asyncOperandsDeviceType, + "::mlir::ArrayAttr":$asyncOnly, + "::mlir::acc::DataClause":$dataClause, "bool":$structured, + "bool":$implicit, "::mlir::StringAttr":$name), + [{ + build($_builder, $_state, accVar, var, varType, bounds, + asyncOperands, asyncOperandsDeviceType, asyncOnly, dataClause, + structured, implicit, ::mlir::acc::DataClauseModifier::none, name); + }]>, + ]; code extraClassDeclarationDataExit = [{ mlir::TypedValue getVarPtr() { @@ -998,7 +1066,8 @@ class OpenACC_DataExitOpNoVarPtr : bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr, /*asyncOnly=*/nullptr, /*dataClause=*/nullptr, /*structured=*/$_builder.getBoolAttr(structured), - /*implicit=*/$_builder.getBoolAttr(implicit), /*name=*/nullptr); + /*implicit=*/$_builder.getBoolAttr(implicit), /*modifiers=*/nullptr, + /*name=*/nullptr); }]>, OpBuilder<(ins "::mlir::Value":$accVar, "bool":$structured, "bool":$implicit, @@ -1009,9 +1078,20 @@ class OpenACC_DataExitOpNoVarPtr : bounds, /*asyncOperands=*/{}, /*asyncOperandsDeviceType=*/nullptr, /*asyncOnly=*/nullptr, /*dataClause=*/nullptr, /*structured=*/$_builder.getBoolAttr(structured), - /*implicit=*/$_builder.getBoolAttr(implicit), + /*implicit=*/$_builder.getBoolAttr(implicit), /*modifiers=*/nullptr, /*name=*/$_builder.getStringAttr(name)); - }]> + }]>, + OpBuilder<(ins "::mlir::Value":$accVar, "::mlir::ValueRange":$bounds, + "::mlir::ValueRange":$asyncOperands, + "::mlir::ArrayAttr":$asyncOperandsDeviceType, + "::mlir::ArrayAttr":$asyncOnly, + "::mlir::acc::DataClause":$dataClause, "bool":$structured, + "bool":$implicit, "::mlir::StringAttr":$name), + [{ + build($_builder, $_state, accVar, bounds, asyncOperands, + asyncOperandsDeviceType, asyncOnly, dataClause, structured, + implicit, ::mlir::acc::DataClauseModifier::none, name); + }]>, ]; code extraClassDeclarationDataExit = [{ diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 0dfead98b7e73..e08fc263e29cc 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -321,6 +321,24 @@ static LogicalResult checkVarAndAccVar(Op op) { return success(); } +template +static LogicalResult checkNoModifier(Op op) { + if (op.getModifiers() != acc::DataClauseModifier::none) + return op.emitError("no data clause modifiers are allowed"); + return success(); +} + +template +static LogicalResult +checkValidModifier(Op op, acc::DataClauseModifier validModifiers) { + if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers)) + return op.emitError( + "invalid data clause modifiers: " + + acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers)); + + return success(); +} + static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var) { // Either `var` or `varPtr` keyword is required. @@ -447,6 +465,8 @@ LogicalResult acc::PrivateOp::verify() { "data clause associated with private operation must match its intent"); if (failed(checkVarAndVarType(*this))) return failure(); + if (failed(checkNoModifier(*this))) + return failure(); return success(); } @@ -459,6 +479,8 @@ LogicalResult acc::FirstprivateOp::verify() { "match its intent"); if (failed(checkVarAndVarType(*this))) return failure(); + if (failed(checkNoModifier(*this))) + return failure(); return success(); } @@ -471,6 +493,8 @@ LogicalResult acc::ReductionOp::verify() { "match its intent"); if (failed(checkVarAndVarType(*this))) return failure(); + if (failed(checkNoModifier(*this))) + return failure(); return success(); } @@ -485,6 +509,8 @@ LogicalResult acc::DevicePtrOp::verify() { return failure(); if (failed(checkVarAndAccVar(*this))) return failure(); + if (failed(checkNoModifier(*this))) + return failure(); return success(); } @@ -499,6 +525,8 @@ LogicalResult acc::PresentOp::verify() { return failure(); if (failed(checkVarAndAccVar(*this))) return failure(); + if (failed(checkNoModifier(*this))) + return failure(); return success(); } @@ -518,11 +546,17 @@ LogicalResult acc::CopyinOp::verify() { return failure(); if (failed(checkVarAndAccVar(*this))) return failure(); + if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly | + acc::DataClauseModifier::always | + acc::DataClauseModifier::capture))) + return failure(); return success(); } bool acc::CopyinOp::isCopyinReadonly() { - return getDataClause() == acc::DataClause::acc_copyin_readonly; + return getDataClause() == acc::DataClause::acc_copyin_readonly || + acc::bitEnumContainsAny(getModifiers(), + acc::DataClauseModifier::readonly); } //===----------------------------------------------------------------------===// @@ -541,13 +575,18 @@ LogicalResult acc::CreateOp::verify() { return failure(); if (failed(checkVarAndAccVar(*this))) return failure(); + if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero | + acc::DataClauseModifier::alwaysout | + acc::DataClauseModifier::capture))) + return failure(); return success(); } bool acc::CreateOp::isCreateZero() { // The zero modifier is encoded in the data clause. return getDataClause() == acc::DataClause::acc_create_zero || - getDataClause() == acc::DataClause::acc_copyout_zero; + getDataClause() == acc::DataClause::acc_copyout_zero || + acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero); } //===----------------------------------------------------------------------===// @@ -561,6 +600,8 @@ LogicalResult acc::NoCreateOp::verify() { return failure(); if (failed(checkVarAndAccVar(*this))) return failure(); + if (failed(checkNoModifier(*this))) + return failure(); return success(); } @@ -575,6 +616,8 @@ LogicalResult acc::AttachOp::verify() { return failure(); if (failed(checkVarAndAccVar(*this))) return failure(); + if (failed(checkNoModifier(*this))) + return failure(); return success(); } @@ -590,6 +633,8 @@ LogicalResult acc::DeclareDeviceResidentOp::verify() { return failure(); if (failed(checkVarAndAccVar(*this))) return failure(); + if (failed(checkNoModifier(*this))) + return failure(); return success(); } @@ -605,6 +650,8 @@ LogicalResult acc::DeclareLinkOp::verify() { return failure(); if (failed(checkVarAndAccVar(*this))) return failure(); + if (failed(checkNoModifier(*this))) + return failure(); return success(); } @@ -626,11 +673,16 @@ LogicalResult acc::CopyoutOp::verify() { return failure(); if (failed(checkVarAndAccVar(*this))) return failure(); + if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero | + acc::DataClauseModifier::always | + acc::DataClauseModifier::capture))) + return failure(); return success(); } bool acc::CopyoutOp::isCopyoutZero() { - return getDataClause() == acc::DataClause::acc_copyout_zero; + return getDataClause() == acc::DataClause::acc_copyout_zero || + acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero); } //===----------------------------------------------------------------------===// @@ -652,6 +704,13 @@ LogicalResult acc::DeleteOp::verify() { " or specify original clause this operation was decomposed from"); if (!getAccVar()) return emitError("must have device pointer"); + // This op is the exit part of copyin and create - thus allow all modifiers + // allowed on either case. + if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero | + acc::DataClauseModifier::readonly | + acc::DataClauseModifier::alwaysin | + acc::DataClauseModifier::capture))) + return failure(); return success(); } @@ -667,6 +726,8 @@ LogicalResult acc::DetachOp::verify() { " or specify original clause this operation was decomposed from"); if (!getAccVar()) return emitError("must have device pointer"); + if (failed(checkNoModifier(*this))) + return failure(); return success(); } @@ -686,6 +747,8 @@ LogicalResult acc::UpdateHostOp::verify() { return failure(); if (failed(checkVarAndAccVar(*this))) return failure(); + if (failed(checkNoModifier(*this))) + return failure(); return success(); } @@ -702,6 +765,8 @@ LogicalResult acc::UpdateDeviceOp::verify() { return failure(); if (failed(checkVarAndAccVar(*this))) return failure(); + if (failed(checkNoModifier(*this))) + return failure(); return success(); } @@ -718,6 +783,8 @@ LogicalResult acc::UseDeviceOp::verify() { return failure(); if (failed(checkVarAndAccVar(*this))) return failure(); + if (failed(checkNoModifier(*this))) + return failure(); return success(); } @@ -735,9 +802,17 @@ LogicalResult acc::CacheOp::verify() { return failure(); if (failed(checkVarAndAccVar(*this))) return failure(); + if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly))) + return failure(); return success(); } +bool acc::CacheOp::isCacheReadonly() { + return getDataClause() == acc::DataClause::acc_cache_readonly || + acc::bitEnumContainsAny(getModifiers(), + acc::DataClauseModifier::readonly); +} + template static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions = 1) { diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir index 8f6e961a06163..d85ad2ff80d80 100644 --- a/mlir/test/Dialect/OpenACC/invalid.mlir +++ b/mlir/test/Dialect/OpenACC/invalid.mlir @@ -819,3 +819,15 @@ func.func @acc_loop_container() { } attributes { collapse = [3], collapseDeviceType = [#acc.device_type], independent = [#acc.device_type]} return } + +// ----- + +%value = memref.alloc() : memref +// expected-error @below {{no data clause modifiers are allowed}} +%0 = acc.private varPtr(%value : memref) -> memref {modifiers = #acc} + +// ----- + +%value = memref.alloc() : memref +// expected-error @below {{invalid data clause modifiers: alwaysin}} +%0 = acc.create varPtr(%value : memref) -> memref {modifiers = #acc} diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir index 97278f869534b..c1d8276d904bb 100644 --- a/mlir/test/Dialect/OpenACC/ops.mlir +++ b/mlir/test/Dialect/OpenACC/ops.mlir @@ -921,6 +921,21 @@ func.func @testdataop(%a: memref, %b: memref, %c: memref) -> () { // ----- +func.func @testdataopmodifiers(%a: memref, %b: memref, %c: memref) -> () { + %0 = acc.create varPtr(%a : memref) -> memref {modifiers = #acc} + %1 = acc.copyin varPtr(%b : memref) -> memref {modifiers = #acc} + acc.data dataOperands(%0, %1 : memref, memref) { + } + acc.copyout accPtr(%0 : memref) to varPtr(%a : memref) {modifiers = #acc} + func.return +} +// CHECK: func @testdataopmodifiers(%[[ARGA:.*]]: memref, %[[ARGB:.*]]: memref, %[[ARGC:.*]]: memref) { +// CHECK: %[[CREATEA:.*]] = acc.create varPtr(%[[ARGA]] : memref) -> memref {modifiers = #acc} +// CHECK: %[[COPYINB:.*]] = acc.copyin varPtr(%[[ARGB]] : memref) -> memref {modifiers = #acc} +// CHECK: acc.copyout accPtr(%[[CREATEA]] : memref) to varPtr(%[[ARGA]] : memref) {modifiers = #acc} + +// ----- + func.func @testupdateop(%a: memref, %b: memref, %c: memref) -> () { %i64Value = arith.constant 1 : i64 %i32Value = arith.constant 1 : i32