@@ -921,6 +921,23 @@ def NVVM_Barrier0Op : NVVM_Op<"barrier0"> {
921921 }];
922922}
923923
924+ // Attrs describing the reduction operations for the barrier operation.
925+ def BarrierReductionPopc : I32EnumAttrCase<"POPC", 0, "popc">;
926+ def BarrierReductionAnd : I32EnumAttrCase<"AND", 1, "and">;
927+ def BarrierReductionOr : I32EnumAttrCase<"OR", 2, "or">;
928+
929+ def BarrierReduction
930+ : I32EnumAttr<"BarrierReduction", "NVVM barrier reduction operation",
931+ [BarrierReductionPopc, BarrierReductionAnd,
932+ BarrierReductionOr]> {
933+ let genSpecializedAttr = 0;
934+ let cppNamespace = "::mlir::NVVM";
935+ }
936+ def BarrierReductionAttr
937+ : EnumAttr<NVVM_Dialect, BarrierReduction, "reduction"> {
938+ let assemblyFormat = "`<` $value `>`";
939+ }
940+
924941def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
925942 let summary = "CTA Barrier Synchronization Op";
926943 let description = [{
@@ -935,6 +952,9 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
935952 - `numberOfThreads`: Specifies the number of threads participating in the barrier.
936953 When specified, the value must be a multiple of the warp size. If not specified,
937954 all threads in the CTA participate in the barrier.
955+ - `reductionOp`: specifies the reduction operation (`popc`, `and`, `or`).
956+ - `reductionPredicate`: specifies the predicate to be used with the
957+ `reductionOp`.
938958
939959 The barrier operation guarantees that when the barrier completes, prior memory
940960 accesses requested by participating threads are performed relative to all threads
@@ -951,31 +971,37 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
951971 [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar)
952972 }];
953973
954- let arguments = (ins
955- Optional<I32>:$barrierId,
956- Optional<I32>:$numberOfThreads);
974+ let extraClassDeclaration = [{
975+ static mlir::NVVM::IDArgPair
976+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
977+ llvm::IRBuilderBase& builder);
978+ }];
979+
980+ let arguments = (ins Optional<I32>:$barrierId, Optional<I32>:$numberOfThreads,
981+ OptionalAttr<BarrierReductionAttr>:$reductionOp,
982+ Optional<I32>:$reductionPredicate);
957983 string llvmBuilder = [{
958- llvm::Value *id = $barrierId ? $barrierId : builder.getInt32(0);
959- if ($numberOfThreads)
960- createIntrinsicCall(
961- builder, llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_count,
962- {id, $numberOfThreads});
963- else
964- createIntrinsicCall(
965- builder, llvm::Intrinsic::nvvm_barrier_cta_sync_aligned_all, {id});
984+ auto [id, args] = NVVM::BarrierOp::getIntrinsicIDAndArgs(
985+ *op, moduleTranslation, builder);
986+ if ($reductionOp)
987+ $res = createIntrinsicCall(builder, id, args);
988+ else
989+ createIntrinsicCall(builder, id, args);
966990 }];
991+ let results = (outs Optional<I32>:$res);
992+
967993 let hasVerifier = 1;
968994
969- let assemblyFormat = "(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? attr-dict";
995+ let assemblyFormat =
996+ "(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? "
997+ "($reductionOp^ $reductionPredicate)? (`->` type($res)^)? attr-dict";
970998
971- let builders = [
972- OpBuilder<(ins), [{
973- return build($_builder, $_state, Value{}, Value{});
999+ let builders = [OpBuilder<(ins), [{
1000+ return build($_builder, $_state, TypeRange{}, Value{}, Value{}, {}, Value{});
9741001 }]>,
975- OpBuilder<(ins "Value":$barrierId), [{
976- return build($_builder, $_state, barrierId, Value{});
977- }]>
978- ];
1002+ OpBuilder<(ins "Value":$barrierId), [{
1003+ return build($_builder, $_state, TypeRange{}, barrierId, Value{}, {}, Value{});
1004+ }]>];
9791005}
9801006
9811007def NVVM_BarrierArriveOp : NVVM_PTXBuilder_Op<"barrier.arrive">
0 commit comments