-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[OpenACC][CIR] Implement 'num_gangs' lowering #137216
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
f24d90d
064fc78
b6e65ae
39691cc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -95,19 +95,41 @@ class OpenACCClauseCIREmitter final | |
| .CaseLower("radeon", mlir::acc::DeviceType::Radeon); | ||
| } | ||
|
|
||
| // Handle a clause affected by the 'device-type' to the point that they need | ||
| // to have the attributes added in the correct/corresponding order, such as | ||
| // 'num_workers' or 'vector_length' on a compute construct. For cases where we | ||
| // don't have an expression 'argument' that needs to be added to an operand | ||
| // and only care about the 'device-type' list, we can use this with 'argument' | ||
| // as 'std::nullopt'. If 'argument' is NOT 'std::nullopt' (that is, has a | ||
| // value), argCollection must also be non-null. For cases where we don't have | ||
| // an argument that needs to be added to an additional one (such as asyncOnly) | ||
| // we can use this with 'argument' as std::nullopt. | ||
| mlir::ArrayAttr handleDeviceTypeAffectedClause( | ||
| mlir::ArrayAttr existingDeviceTypes, | ||
| std::optional<mlir::Value> argument = std::nullopt, | ||
| mlir::MutableOperandRange *argCollection = nullptr) { | ||
| // Overload of this function that only returns the device-types list. | ||
| mlir::ArrayAttr | ||
| handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes) { | ||
| mlir::ValueRange argument; | ||
| mlir::MutableOperandRange range{operation}; | ||
|
|
||
| return handleDeviceTypeAffectedClause(existingDeviceTypes, argument, range); | ||
| } | ||
| // Overload of this function for when 'segments' aren't necessary. | ||
| mlir::ArrayAttr | ||
| handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes, | ||
| mlir::ValueRange argument, | ||
| mlir::MutableOperandRange argCollection) { | ||
| llvm::SmallVector<int32_t> segments; | ||
| assert(argument.size() <= 1 && | ||
| "Overload only for cases where segments don't need to be added"); | ||
| return handleDeviceTypeAffectedClause(existingDeviceTypes, argument, | ||
| argCollection, segments); | ||
| } | ||
|
|
||
| // Handle a clause affected by the 'device_type' to the point that they need | ||
| // to have attributes added in the correct/corresponding order, such as | ||
| // 'num_workers' or 'vector_length' on a compute construct. The 'argument' is | ||
| // a collection of operands that need to be appended to the `argCollection` as | ||
| // we're adding a 'device_type' entry. If there is more than 0 elements in | ||
| // the 'argument', the collection must be non-null, as it is needed to add to | ||
| // it. | ||
| // As some clauses, such as 'num_gangs' or 'wait' require a 'segments' list to | ||
| // be maintained, this takes a list of segments that will be updated with the | ||
| // proper counts as 'argument' elements are added. | ||
| mlir::ArrayAttr | ||
| handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes, | ||
| mlir::ValueRange argument, | ||
| mlir::MutableOperandRange argCollection, | ||
| llvm::SmallVector<int32_t> &segments) { | ||
| llvm::SmallVector<mlir::Attribute> deviceTypes; | ||
|
|
||
| // Collect the 'existing' device-type attributes so we can re-create them | ||
|
|
@@ -126,18 +148,18 @@ class OpenACCClauseCIREmitter final | |
| lastDeviceTypeClause->getArchitectures()) { | ||
| deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get( | ||
| builder.getContext(), decodeDeviceType(arch.getIdentifierInfo()))); | ||
| if (argument) { | ||
| assert(argCollection); | ||
| argCollection->append(*argument); | ||
| if (!argument.empty()) { | ||
| argCollection.append(argument); | ||
| segments.push_back(argument.size()); | ||
| } | ||
| } | ||
| } else { | ||
| // Else, we just add a single for 'none'. | ||
| deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get( | ||
| builder.getContext(), mlir::acc::DeviceType::None)); | ||
| if (argument) { | ||
| assert(argCollection); | ||
| argCollection->append(*argument); | ||
| if (!argument.empty()) { | ||
| argCollection.append(argument); | ||
| segments.push_back(argument.size()); | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -220,7 +242,7 @@ class OpenACCClauseCIREmitter final | |
| mlir::MutableOperandRange range = operation.getNumWorkersMutable(); | ||
| operation.setNumWorkersDeviceTypeAttr(handleDeviceTypeAffectedClause( | ||
| operation.getNumWorkersDeviceTypeAttr(), | ||
| createIntExpr(clause.getIntExpr()), &range)); | ||
| createIntExpr(clause.getIntExpr()), range)); | ||
| } else if constexpr (isOneOfTypes<OpTy, SerialOp>) { | ||
| llvm_unreachable("num_workers not valid on serial"); | ||
| } else { | ||
|
|
@@ -234,7 +256,7 @@ class OpenACCClauseCIREmitter final | |
| mlir::MutableOperandRange range = operation.getVectorLengthMutable(); | ||
| operation.setVectorLengthDeviceTypeAttr(handleDeviceTypeAffectedClause( | ||
| operation.getVectorLengthDeviceTypeAttr(), | ||
| createIntExpr(clause.getIntExpr()), &range)); | ||
| createIntExpr(clause.getIntExpr()), range)); | ||
| } else if constexpr (isOneOfTypes<OpTy, SerialOp>) { | ||
| llvm_unreachable("vector_length not valid on serial"); | ||
| } else { | ||
|
|
@@ -252,7 +274,7 @@ class OpenACCClauseCIREmitter final | |
| mlir::MutableOperandRange range = operation.getAsyncOperandsMutable(); | ||
| operation.setAsyncOperandsDeviceTypeAttr(handleDeviceTypeAffectedClause( | ||
| operation.getAsyncOperandsDeviceTypeAttr(), | ||
| createIntExpr(clause.getIntExpr()), &range)); | ||
| createIntExpr(clause.getIntExpr()), range)); | ||
| } | ||
| } else { | ||
| // Data, enter data, exit data, update, wait, combined remain. | ||
|
|
@@ -301,6 +323,28 @@ class OpenACCClauseCIREmitter final | |
| } | ||
| } | ||
|
|
||
| void VisitNumGangsClause(const OpenACCNumGangsClause &clause) { | ||
| if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) { | ||
| llvm::SmallVector<mlir::Value> values; | ||
|
|
||
| for (const Expr *E : clause.getIntExprs()) | ||
| values.push_back(createIntExpr(E)); | ||
|
|
||
| llvm::SmallVector<int32_t> segments; | ||
| if (operation.getNumGangsSegments()) | ||
| llvm::copy(*operation.getNumGangsSegments(), | ||
| std::back_inserter(segments)); | ||
|
|
||
| mlir::MutableOperandRange range = operation.getNumGangsMutable(); | ||
| operation.setNumGangsDeviceTypeAttr(handleDeviceTypeAffectedClause( | ||
| operation.getNumGangsDeviceTypeAttr(), values, range, segments)); | ||
| operation.setNumGangsSegments(llvm::ArrayRef<int32_t>{segments}); | ||
| } else { | ||
| // combined remains. | ||
|
||
| return clauseNotImplemented(clause); | ||
| } | ||
| } | ||
|
|
||
| void VisitDefaultAsyncClause(const OpenACCDefaultAsyncClause &clause) { | ||
| if constexpr (isOneOfTypes<OpTy, SetOp>) { | ||
| operation.getDefaultAsyncMutable().append( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not clear to me what
segmentsis. Can you add more detail in the comment explaining? It looks like you're pushing the number of arguments in the clause being handled? Can there be a segment with zero arguments between non-zero segments?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm.. 'segments' are a little weird, they are a little bit MLIR/OpenACC-Dialect specific perhaps. I'll try to improve the comment.
As far as zero-arguments between non-zero segments, my understanding is no.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fantastic! Thanks for the expanded explanation.