From 377841b9a9b7192160ad4bc59be1a41276edc7b7 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Sun, 25 May 2025 03:31:17 -0400 Subject: [PATCH 1/4] [mlir][spirv] Implement lowering of `gpu.subgroup_reduce` with cluster size for SPIRV --- mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 45 ++++++++++++------- .../Conversion/GPUToSPIRV/reductions.mlir | 41 +++++++++++++++++ 2 files changed, 70 insertions(+), 16 deletions(-) diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 3cc64b82950b5..f42605a6e8ce1 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -464,27 +464,39 @@ LogicalResult GPUShuffleConversion::matchAndRewrite( template static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, - Value arg, bool isGroup, bool isUniform) { + Value arg, bool isGroup, bool isUniform, + std::optional clusterSize) { Type type = arg.getType(); auto scope = mlir::spirv::ScopeAttr::get(builder.getContext(), isGroup ? spirv::Scope::Workgroup : spirv::Scope::Subgroup); - auto groupOp = spirv::GroupOperationAttr::get(builder.getContext(), - spirv::GroupOperation::Reduce); + auto groupOp = spirv::GroupOperationAttr::get( + builder.getContext(), clusterSize.has_value() + ? spirv::GroupOperation::ClusteredReduce + : spirv::GroupOperation::Reduce); if (isUniform) { return builder.create(loc, type, scope, groupOp, arg) .getResult(); } - return builder.create(loc, type, scope, groupOp, arg, Value{}) + + Value clusterSizeValue = + clusterSize.has_value() + ? 
builder.create( + loc, builder.getI32Type(), + builder.getIntegerAttr(builder.getI32Type(), *clusterSize)) + : Value{}; + return builder + .create(loc, type, scope, groupOp, arg, clusterSizeValue) .getResult(); } -static std::optional createGroupReduceOp(OpBuilder &builder, - Location loc, Value arg, - gpu::AllReduceOperation opType, - bool isGroup, bool isUniform) { +static std::optional +createGroupReduceOp(OpBuilder &builder, Location loc, Value arg, + gpu::AllReduceOperation opType, bool isGroup, + bool isUniform, std::optional clusterSize) { enum class ElemType { Float, Boolean, Integer }; - using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool); + using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool, + std::optional); struct OpHandler { gpu::AllReduceOperation kind; ElemType elemType; @@ -548,7 +560,7 @@ static std::optional createGroupReduceOp(OpBuilder &builder, for (const OpHandler &handler : handlers) if (handler.kind == opType && elementType == handler.elemType) - return handler.func(builder, loc, arg, isGroup, isUniform); + return handler.func(builder, loc, arg, isGroup, isUniform, clusterSize); return std::nullopt; } @@ -571,7 +583,7 @@ class GPUAllReduceConversion final auto result = createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), *opType, - /*isGroup*/ true, op.getUniform()); + /*isGroup*/ true, op.getUniform(), std::nullopt); if (!result) return failure(); @@ -589,16 +601,17 @@ class GPUSubgroupReduceConversion final LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (op.getClusterSize()) + if (op.getClusterStride() > 1) { return rewriter.notifyMatchFailure( - op, "lowering for clustered reduce not implemented"); + op, "lowering for cluster stride > 1 is not implemented"); + } if (!isa(adaptor.getValue().getType())) return rewriter.notifyMatchFailure(op, "reduction type is not a scalar"); - auto result = 
createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), - adaptor.getOp(), - /*isGroup=*/false, adaptor.getUniform()); + auto result = createGroupReduceOp( + rewriter, op.getLoc(), adaptor.getValue(), adaptor.getOp(), + /*isGroup=*/false, adaptor.getUniform(), op.getClusterSize()); if (!result) return failure(); diff --git a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir index ae834b9915d50..08d9b094a5303 100644 --- a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir @@ -789,3 +789,44 @@ gpu.module @kernels { } } } + +// ----- + +module attributes { + gpu.container_module, + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + +gpu.module @kernels { + // CHECK-LABEL: spirv.func @test + // CHECK-SAME: (%[[ARG:.*]]: f32) + // CHECK: %[[CLUSTER_SIZE:.*]] = spirv.Constant 8 : i32 + gpu.func @test22(%arg : f32) kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { + // CHECK: %{{.*}} = spirv.GroupNonUniformFAdd %[[ARG]] cluster_size(%[[CLUSTER_SIZE]]) : f32, i32 -> f32 + %reduced = gpu.subgroup_reduce add %arg cluster(size = 8) : (f32) -> (f32) + gpu.return + } +} + +} + +// ----- + +// Subgroup reduce with cluster stride > 1 is not yet supported. 
+ +module attributes { + gpu.container_module, + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> +} { + +gpu.module @kernels { + gpu.func @test22(%arg : f32) kernel + attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { + // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}} + %reduced = gpu.subgroup_reduce add %arg cluster(size = 8, stride = 2) : (f32) -> (f32) + gpu.return + } +} + +} From 9dbd3d26f2b0e2074cf6c16521718b36d8306586 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Sun, 25 May 2025 20:22:42 -0400 Subject: [PATCH 2/4] Use better function names in test --- mlir/test/Conversion/GPUToSPIRV/reductions.mlir | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir index 08d9b094a5303..e7e0fa296c98a 100644 --- a/mlir/test/Conversion/GPUToSPIRV/reductions.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/reductions.mlir @@ -798,10 +798,10 @@ module attributes { } { gpu.module @kernels { - // CHECK-LABEL: spirv.func @test + // CHECK-LABEL: spirv.func @test_subgroup_reduce_clustered // CHECK-SAME: (%[[ARG:.*]]: f32) // CHECK: %[[CLUSTER_SIZE:.*]] = spirv.Constant 8 : i32 - gpu.func @test22(%arg : f32) kernel + gpu.func @test_subgroup_reduce_clustered(%arg : f32) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { // CHECK: %{{.*}} = spirv.GroupNonUniformFAdd %[[ARG]] cluster_size(%[[CLUSTER_SIZE]]) : f32, i32 -> f32 %reduced = gpu.subgroup_reduce add %arg cluster(size = 8) : (f32) -> (f32) @@ -821,7 +821,7 @@ module attributes { } { gpu.module @kernels { - gpu.func @test22(%arg : f32) kernel + gpu.func @test_invalid_subgroup_reduce_clustered_stride(%arg : f32) kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}} %reduced = gpu.subgroup_reduce add %arg cluster(size = 8, stride = 2) : (f32) -> 
(f32) From a865235355de587f8daff8f6411e6c936fb69fdd Mon Sep 17 00:00:00 2001 From: fairywreath Date: Tue, 3 Jun 2025 15:39:50 -0600 Subject: [PATCH 3/4] Use if statement instead of ternary --- mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index f42605a6e8ce1..67c82f4b9653a 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -479,12 +479,12 @@ static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, .getResult(); } - Value clusterSizeValue = - clusterSize.has_value() - ? builder.create( - loc, builder.getI32Type(), - builder.getIntegerAttr(builder.getI32Type(), *clusterSize)) - : Value{}; + Value clusterSizeValue = {}; + if (clusterSize.has_value()) + clusterSizeValue = builder.create( + loc, builder.getI32Type(), + builder.getIntegerAttr(builder.getI32Type(), *clusterSize)); + return builder .create(loc, type, scope, groupOp, arg, clusterSizeValue) .getResult(); From bc262552b01aebf3dac4067a92b9ad2cc3ae1278 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Tue, 3 Jun 2025 21:16:49 -0400 Subject: [PATCH 4/4] Address review comment --- mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index 67c82f4b9653a..a6fcd741cbf34 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -479,7 +479,7 @@ static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, .getResult(); } - Value clusterSizeValue = {}; + Value clusterSizeValue; if (clusterSize.has_value()) clusterSizeValue = builder.create( loc, builder.getI32Type(),