diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 78e6ebb523a46..46db5d3fdca3b 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -464,27 +464,39 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
 
 template <typename UniformOp, typename NonUniformOp>
 static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
-                                     Value arg, bool isGroup, bool isUniform) {
+                                     Value arg, bool isGroup, bool isUniform,
+                                     std::optional<uint32_t> clusterSize) {
   Type type = arg.getType();
   auto scope = mlir::spirv::ScopeAttr::get(builder.getContext(),
                                            isGroup ? spirv::Scope::Workgroup
                                                    : spirv::Scope::Subgroup);
-  auto groupOp = spirv::GroupOperationAttr::get(builder.getContext(),
-                                                spirv::GroupOperation::Reduce);
+  auto groupOp = spirv::GroupOperationAttr::get(
+      builder.getContext(), clusterSize.has_value()
+                                ? spirv::GroupOperation::ClusteredReduce
+                                : spirv::GroupOperation::Reduce);
   if (isUniform) {
     return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
         .getResult();
   }
-  return builder.create<NonUniformOp>(loc, type, scope, groupOp, arg, Value{})
+
+  Value clusterSizeValue;
+  if (clusterSize.has_value())
+    clusterSizeValue = builder.create<spirv::ConstantOp>(
+        loc, builder.getI32Type(),
+        builder.getIntegerAttr(builder.getI32Type(), *clusterSize));
+
+  return builder
+      .create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue)
       .getResult();
 }
 
-static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
-                                                Location loc, Value arg,
-                                                gpu::AllReduceOperation opType,
-                                                bool isGroup, bool isUniform) {
+static std::optional<Value>
+createGroupReduceOp(OpBuilder &builder, Location loc, Value arg,
+                    gpu::AllReduceOperation opType, bool isGroup,
+                    bool isUniform, std::optional<uint32_t> clusterSize) {
   enum class ElemType { Float, Boolean, Integer };
-  using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
+  using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool,
+                          std::optional<uint32_t>);
   struct OpHandler {
     gpu::AllReduceOperation kind;
     ElemType elemType;
@@ -548,7 +560,7 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
 
   for (const OpHandler &handler : handlers)
     if (handler.kind == opType && elementType == handler.elemType)
-      return handler.func(builder, loc, arg, isGroup, isUniform);
+      return handler.func(builder, loc, arg, isGroup, isUniform, clusterSize);
 
   return std::nullopt;
 }
@@ -571,7 +583,7 @@ class GPUAllReduceConversion final
 
     auto result =
         createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), *opType,
-                            /*isGroup*/ true, op.getUniform());
+                            /*isGroup*/ true, op.getUniform(), std::nullopt);
 
     if (!result)
       return failure();
@@ -589,16 +601,17 @@ class GPUSubgroupReduceConversion final
   LogicalResult
   matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (op.getClusterSize())
+    if (op.getClusterStride() > 1) {
       return rewriter.notifyMatchFailure(
-          op, "lowering for clustered reduce not implemented");
+          op, "lowering for cluster stride > 1 is not implemented");
+    }
 
     if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
       return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");
 
-    auto result = createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(),
-                                      adaptor.getOp(),
-                                      /*isGroup=*/false, adaptor.getUniform());
+    auto result = createGroupReduceOp(
+        rewriter, op.getLoc(), adaptor.getValue(), adaptor.getOp(),
+        /*isGroup=*/false, adaptor.getUniform(), op.getClusterSize());
 
     if (!result)
       return failure();
diff --git a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
index ae834b9915d50..e7e0fa296c98a 100644
--- a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir
@@ -789,3 +789,44 @@ gpu.module @kernels {
   }
 }
 }
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupNonUniformClustered], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL: spirv.func @test_subgroup_reduce_clustered
+  // CHECK-SAME: (%[[ARG:.*]]: f32)
+  // CHECK: %[[CLUSTER_SIZE:.*]] = spirv.Constant 8 : i32
+  gpu.func @test_subgroup_reduce_clustered(%arg : f32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // CHECK: %{{.*}} = spirv.GroupNonUniformFAdd %[[ARG]] cluster_size(%[[CLUSTER_SIZE]]) : f32, i32 -> f32
+    %reduced = gpu.subgroup_reduce add %arg cluster(size = 8) : (f32) -> (f32)
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+// Subgroup reduce with cluster stride > 1 is not yet supported.
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupNonUniformClustered], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  gpu.func @test_invalid_subgroup_reduce_clustered_stride(%arg : f32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+    %reduced = gpu.subgroup_reduce add %arg cluster(size = 8, stride = 2) : (f32) -> (f32)
+    gpu.return
+  }
+}
+
+}