Commit 08fa1e0

newling authored and keshavvinayak01 committed
[Codegen] linalg.generic with dynamic reduction dim: use LLVMGPUVectorDistribution. (iree-org#21430)
### Motivation and next steps

We want to deprecate the WarpReduction pipeline. WarpReduction is not exercised by any e2e numerical test (verified via complete removal, with CI running all e2e tests). I intend to get all lit tests that use WarpReduction working through VectorDistribution, and then reconsider total removal.

### How dynamic reduction size is supported in this PR

There are places where `LLVMGPUVectorDistribution` selection fails because the reduction size is dynamic. This PR instead chooses a large placeholder reduction size for the configuration to target.

### Testing

There is a numerical test in CI exercising this path. I also performed numerical testing locally using:

```mlir
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
func.func @bmm(%arg0: tensor<32x1x?xf16>, %arg1: tensor<32x?x128xf16>) -> tensor<32x1x128xf16> {
  %cst = arith.constant 0.000000e+00 : f16
  %0 = tensor.empty() : tensor<32x1x128xf16>
  %1 = linalg.fill ins(%cst : f16) outs(%0 : tensor<32x1x128xf16>) -> tensor<32x1x128xf16>
  %2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<32x1x?xf16>, tensor<32x?x128xf16>) outs(%1 : tensor<32x1x128xf16>) {
  ^bb0(%in: f16, %in_0: f16, %out: f16):
    %3 = arith.mulf %in, %in_0 : f16
    %4 = arith.addf %out, %3 : f16
    linalg.yield %4 : f16
  } -> tensor<32x1x128xf16>
  return %2 : tensor<32x1x128xf16>
}
```

run with various K values in this script:

```bash
${IREE_BUILD}/tools/iree-compile --iree-hal-target-device=hip --iree-hip-target=gfx942 -o foo.vmfb
export K_VALUES=(1 17 64 1000 9999 99999)
for K in "${K_VALUES[@]}"; do
  echo "Running with K=${K}"
  ${IREE_BUILD}/tools/iree-run-module \
    --device=hip \
    --module=foo.vmfb \
    --input=32x1x${K}xf16=1 \
    --input=32x${K}x128xf16=1 \
    --expected_output=32x1x128xf16=${K}
done
```

All runs produced the expected results: with both inputs filled with ones, each output element is a dot product of K ones, so it should equal K.

---------

Signed-off-by: James Newling <[email protected]>
Signed-off-by: keshavvinayak01 <[email protected]>
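The mechanism above can be condensed into a guard-with-fallback pattern. The sketch below is a minimal paraphrase of the change, not the verbatim IREE code; the helper `pickReductionSizeForConfig` is invented for illustration, while the constant name matches the diff further down:

```cpp
#include <cstdint>

// Mirrors the constant added in KernelConfig.cpp: a large power of two, so the
// power-of-two divisibility checks downstream (subgroup size, thread loads)
// are satisfied by the placeholder value.
constexpr unsigned kVectorDistributeReductionSizeToTargetIfDynamic = 1u << 31;

// Hypothetical helper: where the old code returned failure() on a dynamic
// reduction dimension, the config logic now proceeds as if the reduction
// were very large.
int64_t pickReductionSizeForConfig(int64_t staticSize, bool isDynamic) {
  if (isDynamic)
    return kVectorDistributeReductionSizeToTargetIfDynamic; // was: bail out
  return staticSize;
}
```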
1 parent: ec53dbe

File tree

2 files changed: +64, -16 lines

- compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
- compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir

compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp

Lines changed: 12 additions & 7 deletions

```diff
@@ -140,6 +140,10 @@ namespace {
 
 using CodeGenPipeline = IREE::Codegen::DispatchLoweringPassPipeline;
 
+// If the size of the reduction dimension is not a dispatch compile-time
+// constant, choose a default size that the config should optimize for.
+constexpr unsigned kVectorDistributeReductionSizeToTargetIfDynamic = (1 << 31);
+
 // Threshold used to determine whether a matmul dimension is 'very skinny'.
 constexpr int64_t kVerySkinnyDimThreshold = 4;
 
@@ -527,8 +531,9 @@ getVectorDistributeReductionConfig(
   }
 
   int64_t lastReductionDimSize = bounds[reductionDims.back()];
+
   if (ShapedType::isDynamic(lastReductionDimSize)) {
-    return failure();
+    lastReductionDimSize = kVectorDistributeReductionSizeToTargetIfDynamic;
   }
   if (lastReductionDimSize % threadLoads != 0) {
     return failure();
@@ -834,20 +839,21 @@ setReductionVectorDistributionConfig(IREE::GPU::TargetAttr target,
   SmallVector<int64_t> bounds = op.getStaticLoopRanges();
   IREE::GPU::TargetWgpAttr wgp = target.getWgp();
   int64_t reductionSize = bounds[reductionDims.back()];
+
   if (ShapedType::isDynamic(reductionSize)) {
-    return failure();
+    reductionSize = kVectorDistributeReductionSizeToTargetIfDynamic;
   }
 
-  int64_t numDynamicReductionDims = 0;
+  bool hasDynamicReductionDim = false;
   for (unsigned dim : reductionDims) {
     if (ShapedType::isDynamic(bounds[dim])) {
-      ++numDynamicReductionDims;
+      hasDynamicReductionDim = true;
     }
   }
 
   int64_t subgroupSize = 0;
   for (int s : wgp.getSubgroupSizeChoices().asArrayRef()) {
-    if (reductionSize % s == 0 || numDynamicReductionDims > 0) {
+    if (reductionSize % s == 0 || hasDynamicReductionDim) {
       subgroupSize = s;
       break;
     }
@@ -870,7 +876,7 @@ setReductionVectorDistributionConfig(IREE::GPU::TargetAttr target,
   const unsigned largestLoadSizeInBits = maxLoadBits.value_or(128);
 
   unsigned threadLoads = largestLoadSizeInBits / *bitWidth;
-  if (numDynamicReductionDims == 0) {
+  if (!hasDynamicReductionDim) {
     while ((reductionSize / threadLoads) % subgroupSize != 0) {
       threadLoads /= 2;
     }
@@ -2881,7 +2887,6 @@ static LogicalResult setTransposeConfig(mlir::FunctionOpInterface entryPoint,
   // moving dimension so each thread can execute a vectorized copy of 4
   // contiguous elements at a time from the 32 block.
   std::array<int64_t, 3> workgroupSize = {8, 32, 1};
-
   return setOpConfigAndEntryPointFnTranslation(
       entryPoint, linalgOp, tileSizes,
       CodeGenPipeline::LLVMGPUTransposeSharedMem, workgroupSize);
```
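To see how the placeholder interacts with the selection logic in this file, here is a self-contained distillation. It is a sketch under assumptions: `pickThreadLoads` and its free-function shape are mine, and a termination guard is added that the real code gets from earlier checks; the loop bodies otherwise follow the diff above.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

constexpr unsigned kReductionSizeIfDynamic = 1u << 31; // placeholder, as above

// Returns the number of elements each thread loads per step.
unsigned pickThreadLoads(int64_t reductionSize, bool hasDynamicReductionDim,
                         const std::vector<int> &subgroupSizeChoices,
                         unsigned bitWidth,
                         unsigned largestLoadSizeInBits = 128) {
  if (hasDynamicReductionDim)
    reductionSize = kReductionSizeIfDynamic;

  // Pick the first subgroup size that divides the reduction size; any choice
  // is acceptable when the size is dynamic.
  int subgroupSize = 0;
  for (int s : subgroupSizeChoices) {
    if (reductionSize % s == 0 || hasDynamicReductionDim) {
      subgroupSize = s;
      break;
    }
  }
  assert(subgroupSize != 0 && "no usable subgroup size");

  // With a static size, shrink thread loads until the per-thread count tiles
  // evenly across the subgroup; with a dynamic size this trimming is skipped.
  unsigned threadLoads = largestLoadSizeInBits / bitWidth;
  if (!hasDynamicReductionDim) {
    while (threadLoads > 1 &&
           (reductionSize / threadLoads) % subgroupSize != 0)
      threadLoads /= 2;
  }
  return threadLoads; // e.g. 128 / 16 = 8 for f16 with subgroup size 64
}
```

For the gfx942 test below (f16 operands, subgroup size 64, dynamic K), this yields 8 loads per thread, matching `thread = [0, 0, 0, 8]` in the new CHECK lines.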

compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir

Lines changed: 52 additions & 9 deletions

```diff
@@ -1,13 +1,45 @@
 // RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx942 --pass-pipeline='builtin.module(iree-llvmgpu-select-lowering-strategy)' %s | FileCheck %s
 // RUN: iree-opt --split-input-file --iree-gpu-test-target=gfx1100 --pass-pipeline='builtin.module(iree-llvmgpu-select-lowering-strategy)' %s | FileCheck %s --check-prefix=CDNA3
 
+
+#pipeline_layout = #hal.pipeline.layout<constants = 5, bindings = [#hal.pipeline.binding<storage_buffer>, #hal.pipeline.binding<storage_buffer>, #hal.pipeline.binding<storage_buffer>]>
+func.func @static_batch_matvec() {
+  %cst = arith.constant 0.000000e+00 : f16
+  %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : i32
+  %1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : i32
+  %2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : i32
+  %3 = arith.index_castui %0 : i32 to index
+  %4 = arith.index_castui %1 : i32 to index
+  %5 = arith.index_castui %2 : i32 to index
+  %6 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%5) : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<32x1x128xf16>>
+  %7 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%3) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<32x1x1024xf16>>
+  %8 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%4) flags(ReadOnly) : !iree_tensor_ext.dispatch.tensor<readonly:tensor<32x1024x128xf16>>
+  %9 = iree_tensor_ext.dispatch.tensor.load %7, offsets = [0, 0, 0], sizes = [32, 1, 1024], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<32x1x1024xf16>> -> tensor<32x1x1024xf16>
+  %10 = iree_tensor_ext.dispatch.tensor.load %8, offsets = [0, 0, 0], sizes = [32, 1024, 128], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<32x1024x128xf16>> -> tensor<32x1024x128xf16>
+  %11 = tensor.empty() : tensor<32x1x128xf16>
+  %12 = linalg.fill ins(%cst : f16) outs(%11 : tensor<32x1x128xf16>) -> tensor<32x1x128xf16>
+  %13 = linalg.batch_matmul ins(%9, %10 : tensor<32x1x1024xf16>, tensor<32x1024x128xf16>) outs(%12 : tensor<32x1x128xf16>) -> tensor<32x1x128xf16>
+  iree_tensor_ext.dispatch.tensor.store %13, %6, offsets = [0, 0, 0], sizes = [32, 1, 128], strides = [1, 1, 1] : tensor<32x1x128xf16> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<32x1x128xf16>>
+  return
+}
+
+
+// CHECK: LLVMGPUWarpReduction
+// CDNA3: LLVMGPUTileAndFuse
+
+// We want to deprecate LLVMGPUWarpReduction. Currently LLVMGPUVectorDistribution is not chosen in setReductionVectorDistributionConfig because it fails in 'hasReductionIterator' (which doesn't check specialized ops). This might be an easy whitelisting fix, but I will return to this later (TODO(newling)).
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
 #pipeline_layout = #hal.pipeline.layout<constants = 5, bindings = [
   #hal.pipeline.binding<storage_buffer>,
   #hal.pipeline.binding<storage_buffer>,
   #hal.pipeline.binding<storage_buffer>
 ]>
-func.func @dynamic_batch_matvec() {
-  %c32_i64 = arith.constant 32 : i64
+func.func @dynamic_batch_generic_matvec() {
   %cst = arith.constant 0.000000e+00 : f16
   %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : i32
   %1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : i32
@@ -28,17 +60,28 @@ func.func @dynamic_batch_matvec() {
   %16 = iree_tensor_ext.dispatch.tensor.load %14, offsets = [0, 0, 0], sizes = [32, %12, 128], strides = [1, 1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<32x?x128xf16>>{%12} -> tensor<32x?x128xf16>
   %17 = tensor.empty() : tensor<32x1x128xf16>
   %18 = linalg.fill ins(%cst : f16) outs(%17 : tensor<32x1x128xf16>) -> tensor<32x1x128xf16>
-  %19 = linalg.batch_matmul ins(%15, %16 : tensor<32x1x?xf16>, tensor<32x?x128xf16>) outs(%18 : tensor<32x1x128xf16>) -> tensor<32x1x128xf16>
+  %19 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%15, %16 : tensor<32x1x?xf16>, tensor<32x?x128xf16>) outs(%18 : tensor<32x1x128xf16>) {
+  ^bb0(%in: f16, %in_0: f16, %out: f16):
+    %20 = arith.mulf %in, %in_0 : f16
+    %21 = arith.addf %out, %20 : f16
+    linalg.yield %21 : f16
+  } -> tensor<32x1x128xf16>
   iree_tensor_ext.dispatch.tensor.store %19, %10, offsets = [0, 0, 0], sizes = [32, 1, 128], strides = [1, 1, 1] : tensor<32x1x128xf16> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<32x1x128xf16>>
   return
 }
 
-// CDNA3-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 1, 1], [0, 0, 0, 32]{{\]}}>
-// CDNA3-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUWarpReduction workgroup_size = [32, 1, 1] subgroup_size = 32>
-// CDNA3-LABEL: func.func @dynamic_batch_matvec()
-// CDNA3-SAME: translation_info = #[[$TRANSLATION]]
-// CDNA3: linalg.batch_matmul
-// CDNA3-SAME: lowering_config = #[[$CONFIG]]
+
+// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [1024, 1, 1] subgroup_size = 64
+// CHECK-LABEL: func.func @dynamic_batch_generic_matvec()
+// CHECK-SAME: translation_info = #[[$TRANSLATION]]
+// CHECK: linalg.generic
+// CHECK-SAME: attrs = {lowering_config = #iree_gpu.lowering_config<{
+// CHECK-SAME: partial_reduction = [0, 0, 0, 8192],
+// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 16], [0, 1, 2, 3]],
+// CHECK-SAME: thread = [0, 0, 0, 8],
+// CHECK-SAME: workgroup = [1, 1, 1, 0]
+
+// CDNA3: LLVMGPUVectorDistribute
 
 // -----
 
```
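The new CHECK values are internally consistent; below is a quick arithmetic check. This is my derivation from the attribute values, assuming the usual meaning of the fields, not text from the PR:

```cpp
#include <cassert>

int main() {
  const int bitWidth = 16;                // f16 elements
  const int threadLoads = 128 / bitWidth; // = 8, matches thread = [0, 0, 0, 8]
  const int workgroupThreads = 1024;      // workgroup_size = [1024, 1, 1]
  const int subgroupSize = 64;            // subgroup_size = 64
  assert(workgroupThreads / subgroupSize == 16);  // subgroup_basis [1,1,1,16]
  assert(workgroupThreads * threadLoads == 8192); // partial_reduction 8192
  return 0;
}
```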