Commit 1a13c77
[GPU][DT] Fix matmul narrow dim selection (iree-org#21764)
The old logic:

```cpp
if (ShapedType::isDynamic(n) || m < n) { ... }
if (ShapedType::isDynamic(m) || n < m) { ... }
```

could incorrectly select the narrow dimension when `m` is dynamic (represented by `INT64_MIN`). That case should be handled by the second `if`, but it is accidentally captured by the first `if`, since `m < n` evaluates as true for a dynamic `m`.

This PR also fixes the `iterationSizes` issue that caused compilation failures in llama with data tiling.

---------

Signed-off-by: Yu-Zhewen <[email protected]>
Parent: 73c0d4f
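To make the failure mode concrete, here is a minimal standalone sketch (not IREE code: `kDynamic` and `isDynamic` below are local stand-ins for MLIR's `ShapedType::kDynamic` and `ShapedType::isDynamic`, assuming the `INT64_MIN` representation the commit message describes):

```cpp
#include <cstdint>
#include <iostream>
#include <limits>

// Stand-in for mlir::ShapedType::kDynamic; per the commit message,
// dynamic dimensions are represented by INT64_MIN.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();
constexpr bool isDynamic(int64_t v) { return v == kDynamic; }

int main() {
  const int64_t m = kDynamic; // dynamic M, as in the new dynamic-M test below
  const int64_t n = 513;      // static N

  // Old logic: this branch fires because INT64_MIN < 513 is true,
  // so a dynamic M is misclassified as the narrow dimension.
  if (isDynamic(n) || m < n)
    std::cout << "old logic picks M (wrong for dynamic m)\n";

  // Fixed logic: handle the dynamic-M case explicitly and pick N.
  if (isDynamic(m))
    std::cout << "fixed logic picks N (correct)\n";
  return 0;
}
```

Since `INT64_MIN` compares less than every static extent, an `m < n` test is meaningless until the dynamic cases have been excluded, which is exactly what the reordered checks in `Utils.cpp` below do.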

File tree: 2 files changed, +63 -2 lines changed

compiler/src/iree/compiler/Codegen/Common/test/materialize_encoding_gfx942.mlir (48 additions, 0 deletions)
```diff
@@ -236,6 +236,54 @@ func.func @set_encoding_ACC_unroll8x8x4_MFMA_F32_16x16x4_F32() {
 
 // -----
 
+#encoding = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f32, f32],
+  user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+  iteration_sizes = [?, 513, ?]>
+func.func @set_encoding_ACC_dynamic_M_MFMA_F32_16x16x4_F32(%arg0 : tensor<?x513xf32>) -> tensor<?x513xf32, #encoding> {
+  %0 = iree_encoding.set_encoding %arg0 : tensor<?x513xf32> -> tensor<?x513xf32, #encoding>
+  return %0 : tensor<?x513xf32, #encoding>
+}
+
+// CHECK-LABEL: func.func @set_encoding_ACC_dynamic_M_MFMA_F32_16x16x4_F32
+// CHECK: %[[PACK:.*]] = linalg.pack %{{.+}} padding_value(%{{.+}} : f32)
+// CHECK-SAME: outer_dims_perm = [0, 1]
+// CHECK-SAME: inner_dims_pos = [0, 1]
+// CHECK-SAME: inner_tiles = [128, 128]
+// CHECK-SAME: : tensor<?x513xf32> -> tensor<?x5x128x128xf32>
+// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
+// CHECK-SAME : tensor<?x5x128x128xf32> into tensor<?x5x4x4x2x4x16x8xf32>
+// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<?x5x4x4x2x4x16x8xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<?x5x4x2x8x4x16x4xf32>)
+// CHECK-SAME: permutation = [0, 1, 2, 4, 7, 3, 6, 5]
+// CHECK: return %[[TRANSPOSE]]
+
+// -----
+
+#encoding = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f32, f32],
+  user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+  iteration_sizes = [255, ?, ?]>
+func.func @set_encoding_ACC_dynamic_N_MFMA_F32_16x16x4_F32(%arg0 : tensor<255x?xf32>) -> tensor<255x?xf32, #encoding> {
+  %0 = iree_encoding.set_encoding %arg0 : tensor<255x?xf32> -> tensor<255x?xf32, #encoding>
+  return %0 : tensor<255x?xf32, #encoding>
+}
+
+// CHECK-LABEL: func.func @set_encoding_ACC_dynamic_N_MFMA_F32_16x16x4_F32
+// CHECK: %[[PACK:.*]] = linalg.pack %{{.+}} padding_value(%{{.+}} : f32)
+// CHECK-SAME: outer_dims_perm = [0, 1]
+// CHECK-SAME: inner_dims_pos = [0, 1]
+// CHECK-SAME: inner_tiles = [128, 128]
+// CHECK-SAME: : tensor<255x?xf32> -> tensor<2x?x128x128xf32>
+// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
+// CHECK-SAME : tensor<2x?x128x128xf32> into tensor<2x?x4x8x4x4x16x2xf32>
+// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[EXPAND]] : tensor<2x?x4x8x4x4x16x2xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<2x?x4x8x2x4x16x4xf32>)
+// CHECK-SAME: permutation = [0, 1, 5, 3, 7, 2, 6, 4]
+// CHECK: return %[[TRANSPOSE]]
+
+// -----
+
 #encoding = #iree_encoding.encoding<operand_index = 2, op_type = matmul, element_types = [f32, f32, f32],
   user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
   iteration_sizes = [255, 513, ?]>
```

compiler/src/iree/compiler/Dialect/Encoding/Utils/Utils.cpp (15 additions, 2 deletions)
```diff
@@ -62,17 +62,30 @@ MatmulNarrowDim getPo2MatmulNarrowDim(EncodingAttr encoding) {
   // and vecmat, so set to 1 if empty.
   const int64_t m = cDims.m.empty() ? 1 : iterationSizes[cDims.m[0]];
   const int64_t n = cDims.n.empty() ? 1 : iterationSizes[cDims.n[0]];
+
+  // If both dimensions are dynamic, return empty.
   if (ShapedType::isDynamic(m) && ShapedType::isDynamic(n)) {
     return {};
   }
-  if (ShapedType::isDynamic(n) || m < n) {
+  // If only one dimension is dynamic, pick the other as the narrow dimension.
+  if (ShapedType::isDynamic(m)) {
+    return {MatmulNarrowDim::Dim::N,
+            static_cast<int64_t>(llvm::PowerOf2Ceil(n))};
+  }
+  if (ShapedType::isDynamic(n)) {
     return {MatmulNarrowDim::Dim::M,
             static_cast<int64_t>(llvm::PowerOf2Ceil(m))};
   }
-  if (ShapedType::isDynamic(m) || n < m) {
+  // If Both dimensions are static, pick the smaller one.
+  if (n < m) {
     return {MatmulNarrowDim::Dim::N,
             static_cast<int64_t>(llvm::PowerOf2Ceil(n))};
   }
+  if (m < n) {
+    return {MatmulNarrowDim::Dim::M,
+            static_cast<int64_t>(llvm::PowerOf2Ceil(m))};
+  }
+  // If dimensions are static and equal, return empty.
   return {};
 }
 
```
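For a quick sanity check of the new control flow, the sketch below reimplements the selection order outside of IREE; `Dim`, `NarrowDim`, and `kDynamic` are simplified local stand-ins for the real types, and C++20 `std::bit_ceil` stands in for `llvm::PowerOf2Ceil`:

```cpp
#include <bit>
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

// Simplified stand-ins for ShapedType::kDynamic and MatmulNarrowDim.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();
constexpr bool isDynamic(int64_t v) { return v == kDynamic; }
enum class Dim { M, N };
struct NarrowDim {
  Dim dim;
  int64_t size;
};

// Mirrors the fixed logic: resolve dynamic cases first, then compare.
std::optional<NarrowDim> getPo2NarrowDim(int64_t m, int64_t n) {
  if (isDynamic(m) && isDynamic(n))
    return std::nullopt; // both dynamic: no narrow dim
  if (isDynamic(m))
    return NarrowDim{Dim::N, int64_t(std::bit_ceil(uint64_t(n)))};
  if (isDynamic(n))
    return NarrowDim{Dim::M, int64_t(std::bit_ceil(uint64_t(m)))};
  if (n < m)
    return NarrowDim{Dim::N, int64_t(std::bit_ceil(uint64_t(n)))};
  if (m < n)
    return NarrowDim{Dim::M, int64_t(std::bit_ceil(uint64_t(m)))};
  return std::nullopt; // static and equal: no narrow dim
}

int main() {
  // Dynamic M, N = 513 (the dynamic-M test): narrow dim is N, rounded to 1024.
  auto r = getPo2NarrowDim(kDynamic, 513);
  assert(r && r->dim == Dim::N && r->size == 1024);
  // M = 255, dynamic N (the dynamic-N test): narrow dim is M, rounded to 256.
  r = getPo2NarrowDim(255, kDynamic);
  assert(r && r->dim == Dim::M && r->size == 256);
  // Equal static sizes: no narrow dimension.
  assert(!getPo2NarrowDim(128, 128));
  return 0;
}
```

The key change is ordering: in the old code, `m < n` ran before the dynamic-`m` check, so `INT64_MIN` leaked into the comparison; the fixed version never compares a dynamic value.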
