Commit b4e3694
[GPU] Bail out in matmul TileAndFuse config for unaligned dynamic shapes (#20622)

We can't always handle such shapes with the padding method used by the pipeline.

Fixes: #20581

Signed-off-by: Nirvedh <[email protected]>
1 parent 660cf92 commit b4e3694

File tree: 2 files changed (+74, -10)

compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp

Lines changed: 23 additions & 10 deletions
@@ -205,30 +205,43 @@ getMatmulLoweringConfigAndWorkgroupSize(SmallVector<int64_t> bounds,
     return failure();
   }
 
+  // We can support unaligned shapes as long as there are no dynamic dimensions
+  // as finding padding bounds for dynamic dimensions is not guaranteed.
+  // TODO(nirvedhmeshram): Add support so that we can find the bounds
+  // information.
+  bool canSupportUnaligned = true;
+
   // Gather all static M, N, and K dimensions to deduce the MMASchedule. Dynamic
   // dimensions will be tiled to 1 in workgroup tiling, so they are ignored when
   // computing an MMA schedule.
   SmallVector<int64_t> mDims, nDims, kDims, batchDims;
   for (int64_t mDim : contractionDims.m) {
-    if (!ShapedType::isDynamic(bounds[mDim])) {
-      mDims.push_back(mDim);
+    if (ShapedType::isDynamic(bounds[mDim])) {
+      canSupportUnaligned = false;
+      continue;
     }
+    mDims.push_back(mDim);
   }
   for (int64_t nDim : contractionDims.n) {
-    if (!ShapedType::isDynamic(bounds[nDim])) {
-      nDims.push_back(nDim);
+    if (ShapedType::isDynamic(bounds[nDim])) {
+      canSupportUnaligned = false;
+      continue;
     }
+    nDims.push_back(nDim);
   }
   for (int64_t kDim : contractionDims.k) {
-    if (!ShapedType::isDynamic(bounds[kDim])) {
-      kDims.push_back(kDim);
+    if (ShapedType::isDynamic(bounds[kDim])) {
+      canSupportUnaligned = false;
+      continue;
     }
+    kDims.push_back(kDim);
   }
-
   for (int64_t batchDim : contractionDims.batch) {
-    if (!ShapedType::isDynamic(bounds[batchDim])) {
-      batchDims.push_back(batchDim);
+    if (ShapedType::isDynamic(bounds[batchDim])) {
+      canSupportUnaligned = false;
+      continue;
     }
+    batchDims.push_back(batchDim);
   }
 
   auto getDimBounds = [&](SmallVector<int64_t> dims) -> SmallVector<int64_t> {
@@ -267,7 +280,7 @@ getMatmulLoweringConfigAndWorkgroupSize(SmallVector<int64_t> bounds,
   // the GEMM is accumulating (i.e doesnt have a zero fill dpsInit) as that
   // buffer currently gets materialized as private memory. We need to add
   // missing patterns to fix that.
-  if (!schedule) {
+  if (!schedule && canSupportUnaligned) {
     LDBG("Attempting to deduce unaligned TileAndFuse MMA schedulee");
     mustBeAligned = false;
     doCPromotion = true;
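
Distilled, the change above threads a single canSupportUnaligned flag through the dimension-gathering loops and uses it to gate the unaligned-schedule fallback. The following is a minimal standalone sketch of that control flow, not the patch itself: it stands in for ShapedType::isDynamic with a plain sentinel, and collectStaticDims and the main() driver are illustrative names. The example bounds mirror the new @unaligned_dynamic_matmul_with_two_reduce_dim test below, where a dynamic reduction dimension forces the bail-out.

// Minimal sketch of the gating logic added above. ShapedType::isDynamic is
// mimicked with a plain sentinel; collectStaticDims and main() are
// illustrative, not part of the patch.
#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for ShapedType::kDynamic (MLIR uses a sentinel value).
constexpr int64_t kDynamic = INT64_MIN;
static bool isDynamic(int64_t size) { return size == kDynamic; }

// Collect the static dims of one kind (M, N, K, or batch); any dynamic dim
// clears canSupportUnaligned, since padding bounds cannot be derived for it.
static std::vector<int64_t> collectStaticDims(const std::vector<int64_t> &bounds,
                                              const std::vector<int64_t> &dims,
                                              bool &canSupportUnaligned) {
  std::vector<int64_t> staticDims;
  for (int64_t dim : dims) {
    if (isDynamic(bounds[dim])) {
      canSupportUnaligned = false;
      continue;
    }
    staticDims.push_back(dim);
  }
  return staticDims;
}

int main() {
  // bounds[i] is the loop bound of iteration dimension di; d1 is dynamic,
  // matching the unaligned dynamic test case (196 x ? x 16 x 4).
  std::vector<int64_t> bounds = {196, kDynamic, 16, 4};
  std::vector<int64_t> mDimIdx = {0}, nDimIdx = {2}, kDimIdx = {1, 3};

  bool canSupportUnaligned = true;
  auto mDims = collectStaticDims(bounds, mDimIdx, canSupportUnaligned);
  auto nDims = collectStaticDims(bounds, nDimIdx, canSupportUnaligned);
  auto kDims = collectStaticDims(bounds, kDimIdx, canSupportUnaligned);

  bool schedule = false; // pretend aligned MMA schedule deduction failed
  if (!schedule && canSupportUnaligned) {
    std::cout << "retry with padded (unaligned) MMA schedule\n";
  } else if (!schedule) {
    std::cout << "bail out: dynamic dims, fall back to non-MMA config\n";
  }
  return 0;
}

With a dynamic dimension present the flag comes out false, so a failed aligned-schedule deduction no longer retries with padding; the op instead falls through to a non-MMA lowering config, as the new test's CHECK lines show.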

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_tile_and_fuse.mlir

Lines changed: 51 additions & 0 deletions
@@ -330,6 +330,57 @@ func.func @unaligned_matmul_with_two_reduce_dim(%arg0: tensor<196x9x4xf32>, %arg
 
 // -----
 
+module {
+  func.func @aligned_dynamic_matmul_with_two_reduce_dim(%arg0: tensor<192x?x16xf32>, %arg1: tensor<?x16x16xf32>) -> tensor<192x16xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = tensor.empty() : tensor<192x16xf32>
+    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<192x16xf32>) -> tensor<192x16xf32>
+    %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2)>], iterator_types = ["parallel", "reduction", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<192x?x16xf32>, tensor<?x16x16xf32>) outs(%1 : tensor<192x16xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %3 = arith.mulf %in, %in_0 : f32
+      %4 = arith.addf %out, %3 : f32
+      linalg.yield %4 : f32
+    } -> tensor<192x16xf32>
+    return %2 : tensor<192x16xf32>
+  }
+}
+
+// CHECK-LABEL: func.func @aligned_dynamic_matmul_with_two_reduce_dim
+// CHECK-SAME:    {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [128, 1, 1] subgroup_size = 64
+// CHECK:       linalg.generic
+// CHECK-SAME:    {lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x4_F32>
+// CHECK-SAME:    promote_operands = [0, 1]
+// CHECK-SAME:    reduction = [0, 1, 0, 4],
+// CHECK-SAME:    subgroup = [2, 0, 1, 0],
+// CHECK-SAME:    workgroup = [64, 0, 16, 0]}
+
+// -----
+
+module {
+  func.func @unaligned_dynamic_matmul_with_two_reduce_dim(%arg0: tensor<196x?x4xf32>, %arg1: tensor<?x16x4xf32>) -> tensor<196x16xf32> {
+    %cst = arith.constant 0.000000e+00 : f32
+    %0 = tensor.empty() : tensor<196x16xf32>
+    %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<196x16xf32>) -> tensor<196x16xf32>
+    %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2)>], iterator_types = ["parallel", "reduction", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<196x?x4xf32>, tensor<?x16x4xf32>) outs(%1 : tensor<196x16xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %3 = arith.mulf %in, %in_0 : f32
+      %4 = arith.addf %out, %3 : f32
+      linalg.yield %4 : f32
+    } -> tensor<196x16xf32>
+    return %2 : tensor<196x16xf32>
+  }
+}
+
+// CHECK-LABEL: func.func @unaligned_dynamic_matmul_with_two_reduce_dim
+// CHECK-SAME:    {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64
+// CHECK:       linalg.generic
+// CHECK-SAME:    promote_operands = [0, 1]
+// CHECK-SAME:    reduction = [0, 4, 0, 4],
+// CHECK-SAME:    thread = [1, 0, 1, 0],
+// CHECK-SAME:    workgroup = [4, 0, 16, 0]}
+
+// -----
+
 module {
   func.func @unaligned_to_intrinsic_batched_matmul_tiling_check(%lhs : tensor<12x577x577xf32>, %rhs : tensor<12x577x1024xf32>) -> tensor<12x577x1024xf32> {
     %c0 = arith.constant 0.0 : f32
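
For context on the test names: "aligned" here means every static M/N/K dimension is divisible by the corresponding dimension of the MFMA_F32_16x16x4_F32 intrinsic (16/16/4) that the CHECK lines select, so no padding is needed even though one reduction dimension is dynamic. A back-of-the-envelope sketch of that divisibility check, with isAlignedTo as a hypothetical helper rather than anything from the patch:

// Illustrative only: what "aligned" means for the intrinsic named in the
// CHECK lines above (M/N/K of 16/16/4 for MFMA_F32_16x16x4_F32).
#include <cstdint>
#include <iostream>

static bool isAlignedTo(int64_t dim, int64_t intrinsicDim) {
  return dim % intrinsicDim == 0;
}

int main() {
  // @aligned_dynamic_...: M=192 and N=16 divide evenly by 16 -> aligned path.
  std::cout << isAlignedTo(192, 16) << isAlignedTo(16, 16) << "\n"; // 11
  // @unaligned_dynamic_...: M=196 does not divide by 16 -> would need
  // padding, which the patch now refuses when a dimension is dynamic.
  std::cout << isAlignedTo(196, 16) << "\n"; // 0
  return 0;
}

This is why the aligned test still gets an mma_kind in its lowering config, while the 196-sized variant falls back to a thread-distributed configuration with no MMA intrinsic.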
