Skip to content

Commit 8da43fe

Browse files
authored
[Codegen] Disallow padding more skinny matmuls (iree-org#20289)
Fix up the threshold from iree-org#20284 to disallow other skinny matmuls. Add a TODO to refactor this in the future. Also add debug prints for padding with `--debug-only=iree-encoding-attrs`. --------- Signed-off-by: Jakub Kuderski <[email protected]>
1 parent 6d9f736 commit 8da43fe

File tree

3 files changed

+15
-6
lines changed

3 files changed

+15
-6
lines changed

compiler/src/iree/compiler/Codegen/ExternalInterfaces/GPUEncodingExternalModels.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ struct GPUPadEncodingLayoutResolverAttrInterface final
449449
return noPaddingAttr;
450450
}
451451

452-
// Bail out on matvec / vecmat problems.
452+
// Bail out on matvec / vecmat and skinny matmul problems.
453453
{
454454
int64_t parallelDimSize = 1;
455455
ArrayRef<unsigned> parallelDims =
@@ -466,10 +466,11 @@ struct GPUPadEncodingLayoutResolverAttrInterface final
466466
}
467467
}
468468

469-
static constexpr int64_t kMatVecThreshold = 16;
469+
// TODO(#19897): Use `getMatmulNarrowDim`.
470+
static constexpr int64_t kSkinnyMatmulThreshold = 64;
470471
if (!ShapedType::isDynamic(parallelDimSize) &&
471-
parallelDimSize < kMatVecThreshold) {
472-
// This matmul is more similar to a matvec, do not pad.
472+
parallelDimSize < kSkinnyMatmulThreshold) {
473+
// This matmul is skinny, do not pad.
473474
return noPaddingAttr;
474475
}
475476
}

compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
88

99
#include "llvm/ADT/SmallVector.h"
10+
#include "llvm/Support/Debug.h"
1011
#include "mlir/Dialect/Arith/IR/Arith.h"
1112
#include "mlir/IR/AffineMap.h"
1213
#include "mlir/IR/Attributes.h"
@@ -19,6 +20,10 @@
1920

2021
#include <cassert>
2122

23+
#define DEBUG_TYPE "iree-encoding-attrs"
24+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
25+
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
26+
2227
namespace mlir::iree_compiler::IREE::Encoding {
2328

2429
//===---------------------------------------------------------------------===//
@@ -338,6 +343,9 @@ Value PadEncodingLayoutAttr::calculateStorageSizeInBytes(
338343
ValueRange dynamicDims) const {
339344
ArrayRef<int32_t> padding = getPadding().asArrayRef();
340345
assert(padding.size() == type.getRank() && "Invalid padding");
346+
LLVM_DEBUG(if (llvm::any_of(padding, [](int32_t x) { return x != 0; })) {
347+
llvm::dbgs() << "Non-zero padding: " << type << "\n";
348+
});
341349

342350
const int64_t elementSize = getRoundedElementByteWidth(type.getElementType());
343351
int64_t staticProduct = elementSize;

compiler/src/iree/compiler/Dialect/Stream/Transforms/test/specialize_encodings.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ util.func public @with_pad_encoding(%arg0: index, %arg1: index, %scalar_f32 : f3
103103
%2 = stream.tensor.empty on(#hal.device.affinity<@device_a>) : tensor<4096x1337xf16, #encodingA>{} in !stream.resource<*>{%arg1}
104104
%3 = stream.tensor.empty on(#hal.device.affinity<@device_a>) : tensor<4096x4095xf16, #encodingA>{} in !stream.resource<*>{%arg1}
105105
%4 = stream.tensor.empty on(#hal.device.affinity<@device_a>) : tensor<4096x250xf16, #encodingA>{} in !stream.resource<*>{%arg1}
106-
%5 = stream.tensor.empty on(#hal.device.affinity<@device_a>) : tensor<15x4096xf16, #encodingA>{} in !stream.resource<*>{%arg1}
106+
%5 = stream.tensor.empty on(#hal.device.affinity<@device_a>) : tensor<60x4096xf16, #encodingA>{} in !stream.resource<*>{%arg1}
107107
%6 = stream.tensor.empty on(#hal.device.affinity<@device_a>) : tensor<1x4096xf16, #encodingB>{} in !stream.resource<*>{%arg1}
108108
%7 = stream.tensor.empty on(#hal.device.affinity<@device_a>) : tensor<?x4096xf16, #encodingA>{%arg0} in !stream.resource<*>{%arg1}
109109
%8 = stream.tensor.empty on(#hal.device.affinity<@device_a>) : tensor<?x?xf16, #encodingA>{%arg0, %arg1} in !stream.resource<*>{%arg1}
@@ -128,7 +128,7 @@ util.func public @with_pad_encoding(%arg0: index, %arg1: index, %scalar_f32 : f3
128128
// CHECK: stream.tensor.empty {{.*}} : tensor<4096x1337xf16, #[[$PAD_LHS_1]]>
129129
// CHECK: stream.tensor.empty {{.*}} : tensor<4096x4095xf16, #[[$PAD_LHS_2]]>
130130
// CHECK: stream.tensor.empty {{.*}} : tensor<4096x250xf16, #[[$NO_PAD_LHS]]>
131-
// CHECK: stream.tensor.empty {{.*}} : tensor<15x4096xf16, #[[$NO_PAD_LHS]]>
131+
// CHECK: stream.tensor.empty {{.*}} : tensor<60x4096xf16, #[[$NO_PAD_LHS]]>
132132
// CHECK: stream.tensor.empty {{.*}} : tensor<1x4096xf16, #[[$NO_PAD_RHS]]>
133133
// CHECK: stream.tensor.empty {{.*}} : tensor<?x4096xf16, #[[$PAD_LHS_0]]>
134134
// CHECK: stream.tensor.empty {{.*}} : tensor<?x?xf16, #[[$NO_PAD_LHS]]>

0 commit comments

Comments
 (0)