Skip to content

Commit e84ea8d

Browse files
authored
[DT][VMVX] Implement VMVXEncodingLayoutAttr. (#19403)
The revision implements the VMVX encoding layout attribute for device code with the new IREE::CPU dialect. Some additional dialect registration is needed because it could create the VMVXEncodingLayoutAttr attribute in the pass pipeline; the dialect needs to be loaded. The main refactoring changes are: - Move chooseMatmulTile to IREECodegen/Utils. We can internalize the method once other CPU materialization logic is moved to their own attribute implementation. - Move the enumerateMatmulTilesVMVX method to CPUEncodingExternalModels.cpp. It is only used by VMVX implementation. On the Codegen utilitiy side (i.e., Codegen/Utils.[h|cpp]), the revision adapts config query methods to use `Attribute`, and share the logic with DictionaryAttr type. Previously, the configuration is wrapped into IREE::HAL::ExecutableTarget attribute. Now we have DictionaryAttr variants and we do not want to duplicate the implementation. Thus, the functions take Attribute input and handle the cases in the implementation. On the data-tiling encoding materialization pass side, it creates the VMVXEncodingLayoutAttr attribute (with the original target configuration). Note that the target configuration is an optional parameter and it is expected to be used within pass scope, but not the final IR output. No additional tests because they are covered by vmvx_materialize_encoding.mlir. --------- Signed-off-by: hanhanW <[email protected]>
1 parent ab3c9bb commit e84ea8d

30 files changed

+750
-209
lines changed

compiler/plugins/target/LLVMCPU/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ iree_compiler_cc_library(
3333
":StaticLibraryGenerator",
3434
"//compiler/plugins/target/LLVMCPU/Builtins",
3535
"//compiler/src/iree/compiler/Codegen/Common",
36+
"//compiler/src/iree/compiler/Codegen/Dialect/CPU/IR:IREECPUDialect",
3637
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
3738
"//compiler/src/iree/compiler/Codegen/LLVMCPU",
3839
"//compiler/src/iree/compiler/Codegen/Utils",

compiler/plugins/target/LLVMCPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ iree_cc_library(
5454
MLIRTargetLLVMIRExport
5555
MLIRTransformDialect
5656
iree::compiler::Codegen::Common
57+
iree::compiler::Codegen::Dialect::CPU::IR::IREECPUDialect
5758
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
5859
iree::compiler::Codegen::LLVMCPU
5960
iree::compiler::Codegen::Utils

compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "compiler/plugins/target/LLVMCPU/LibraryBuilder.h"
1616
#include "compiler/plugins/target/LLVMCPU/LinkerTool.h"
1717
#include "compiler/plugins/target/LLVMCPU/StaticLibraryGenerator.h"
18+
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUDialect.h"
1819
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
1920
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
2021
#include "iree/compiler/Codegen/LLVMCPU/Utils.h"
@@ -218,6 +219,7 @@ class LLVMCPUTargetBackend final : public TargetBackend {
218219
// TODO: make inclusion of ArmNeon conditional?
219220
// clang-format off
220221
registry.insert<IREE::Codegen::IREECodegenDialect,
222+
IREE::CPU::IREECPUDialect,
221223
IREE::LinalgExt::IREELinalgExtDialect,
222224
mlir::transform::TransformDialect,
223225
pdl::PDLDialect,

compiler/plugins/target/VMVX/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ iree_compiler_cc_library(
2323
"VMVXTarget.cpp",
2424
],
2525
deps = [
26+
"//compiler/src/iree/compiler/Codegen/Dialect/CPU/IR:IREECPUDialect",
2627
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
2728
"//compiler/src/iree/compiler/Codegen/VMVX",
2829
"//compiler/src/iree/compiler/Dialect/HAL/Target",

compiler/plugins/target/VMVX/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ iree_cc_library(
2727
MLIRIR
2828
MLIRPass
2929
MLIRSupport
30+
iree::compiler::Codegen::Dialect::CPU::IR::IREECPUDialect
3031
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
3132
iree::compiler::Codegen::VMVX
3233
iree::compiler::Dialect::HAL::Target

compiler/plugins/target/VMVX/VMVXTarget.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

7+
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUDialect.h"
78
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
89
#include "iree/compiler/Codegen/VMVX/Passes.h"
910
#include "iree/compiler/Dialect/HAL/Target/Devices/LocalDevice.h"
@@ -77,9 +78,10 @@ class VMVXTargetBackend final : public TargetBackend {
7778
}
7879

7980
void getDependentDialects(DialectRegistry &registry) const override {
80-
registry.insert<IREE::Codegen::IREECodegenDialect, IREE::VM::VMDialect,
81-
IREE::VMVX::VMVXDialect,
82-
IREE::LinalgExt::IREELinalgExtDialect>();
81+
registry
82+
.insert<IREE::Codegen::IREECodegenDialect, IREE::CPU::IREECPUDialect,
83+
IREE::VM::VMDialect, IREE::VMVX::VMVXDialect,
84+
IREE::LinalgExt::IREELinalgExtDialect>();
8385
}
8486

8587
IREE::VM::TargetOptions
@@ -232,8 +234,8 @@ class VMVXInlineTargetBackend final : public TargetBackend {
232234
}
233235

234236
void getDependentDialects(DialectRegistry &registry) const override {
235-
registry
236-
.insert<IREE::Codegen::IREECodegenDialect, IREE::VMVX::VMVXDialect>();
237+
registry.insert<IREE::Codegen::IREECodegenDialect,
238+
IREE::CPU::IREECPUDialect, IREE::VMVX::VMVXDialect>();
237239
}
238240

239241
void

compiler/src/iree/compiler/Codegen/Common/CPU/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ iree_compiler_cc_library(
5656
":PassHeaders",
5757
":PassesIncGen",
5858
"//compiler/src/iree/compiler/Codegen/Common",
59+
"//compiler/src/iree/compiler/Codegen/Dialect/CPU/IR:IREECPUDialect",
5960
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
6061
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils",
6162
"//compiler/src/iree/compiler/Codegen/Interfaces:UKernelOpInterface",

compiler/src/iree/compiler/Codegen/Common/CPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ iree_cc_library(
7878
MLIRVectorTransforms
7979
iree::builtins::ukernel::exported_bits
8080
iree::compiler::Codegen::Common
81+
iree::compiler::Codegen::Dialect::CPU::IR::IREECPUDialect
8182
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
8283
iree::compiler::Codegen::Dialect::Codegen::Utils
8384
iree::compiler::Codegen::Interfaces::UKernelOpInterface

compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp

Lines changed: 20 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
#include "iree/compiler/Codegen/Common/CPU/Passes.h"
88
#include "iree/compiler/Codegen/Common/EncodingUtils.h"
9+
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUDialect.h"
10+
#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUTypes.h"
911
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
1012
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
1113
#include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
@@ -38,39 +40,9 @@ using IREE::Codegen::TileMxNxK;
3840
#define GEN_PASS_DEF_CPUMATERIALIZEHOSTENCODINGPASS
3941
#include "iree/compiler/Codegen/Common/CPU/Passes.h.inc"
4042

41-
// Enumerate tile sizes to choose from when no specific architecture is
42-
// targeted. For narrow-{M,N} cases, this only enumerates on narrow M. The
43-
// narrow-N cases are handled by transposition in chooseMatmulTile.
44-
static SmallVector<TileMxNxK>
45-
enumerateMatmulTilesVMVX(linalg::ContractionDimensions cDims,
46-
IREE::Encoding::EncodingAttr encoding,
47-
IREE::HAL::ExecutableTargetAttr target) {
48-
bool hasUkernelSupport = hasUkernel(target);
49-
50-
// TODO(hanchung): The ukernel path does not support 3d
51-
// codegen.query_tile_sizes op, so we disable dynamic tile shapes for
52-
// batch_matmul. Also, they are not set up for narrow M/N matmul, so it is
53-
// disabled when it is the case.
54-
if (!cDims.batch.empty() || getMatmulNarrowDim(encoding)) {
55-
hasUkernelSupport = false;
56-
}
57-
if (hasUkernelSupport) {
58-
// VMVX+ukernel uses dynamic tile shapes.
59-
return {TileMxNxK{ShapedType::kDynamic, ShapedType::kDynamic,
60-
ShapedType::kDynamic}};
61-
}
62-
63-
return {
64-
TileMxNxK{8, 8, 4}, // Some vaguely reasonable tile shape.
65-
TileMxNxK{4, 8, 4}, // Truncation of the above.
66-
TileMxNxK{2, 8, 4}, // Truncation of the above.
67-
TileMxNxK{1, 8, 4}, // Truncation of the above.
68-
};
69-
}
70-
7143
// Enumerate tile sizes to choose from on riscv32.
7244
// For narrow-{M,N} cases, this only enumerates on narrow M. The narrow-N cases
73-
// are handled by transposition in chooseMatmulTile.
45+
// are handled by transposition in IREE::Codegen::chooseMatmulTile.
7446
static SmallVector<TileMxNxK>
7547
enumerateMatmulTileRiscv32(IREE::HAL::ExecutableTargetAttr target) {
7648
if (hasUkernel(target)) {
@@ -87,7 +59,7 @@ enumerateMatmulTileRiscv32(IREE::HAL::ExecutableTargetAttr target) {
8759

8860
// Enumerate tile sizes to choose from on arm64.
8961
// For narrow-{M,N} cases, this only enumerates on narrow M. The narrow-N cases
90-
// are handled by transposition in chooseMatmulTile.
62+
// are handled by transposition in IREE::Codegen::chooseMatmulTile.
9163
static SmallVector<TileMxNxK>
9264
enumerateMatmulTileArm64(TypeRange elementTypes,
9365
IREE::HAL::ExecutableTargetAttr target) {
@@ -178,7 +150,7 @@ enumerateMatmulTileArm64(TypeRange elementTypes,
178150

179151
// Enumerate tile sizes to choose from on x86-64.
180152
// For narrow-{M,N} cases, this only enumerates on narrow M. The narrow-N cases
181-
// are handled by transposition in chooseMatmulTile.
153+
// are handled by transposition in IREE::Codegen::chooseMatmulTile.
182154
static SmallVector<TileMxNxK>
183155
enumerateMatmulTileX86_64(TypeRange elementTypes,
184156
IREE::HAL::ExecutableTargetAttr target) {
@@ -291,114 +263,6 @@ enumerateMatmulTileX86_64(TypeRange elementTypes,
291263
return {};
292264
}
293265

294-
/// Returns the best TileMxNxK from `enumeratedTiles` pool. If the
295-
/// `hostDefinedUpperBound` is not empty, the chosen tile sizes can not be
296-
/// greater than the values.
297-
/// TODO(#16933): Remove `hostDefinedUpperBound` once we can propagate such
298-
/// information to host. For now, they are defined by host.
299-
static TileMxNxK
300-
chooseMatmulTile(ArrayRef<TileMxNxK> enumeratedTiles,
301-
IREE::Encoding::MatmulNarrowDim narrowDim,
302-
ArrayRef<int64_t> hostDefinedUpperBound = {}) {
303-
assert((hostDefinedUpperBound.empty() || hostDefinedUpperBound.size() >= 3) &&
304-
"expected hostDefinedUpperBound is empty or has upper bound for {M, "
305-
"N, K}");
306-
// Handle narrow-N by transposing to reduce to narrow-M. Note: the
307-
// enumeratedTiles currently only enumerate narrow-M cases.
308-
if (narrowDim.isN()) {
309-
SmallVector<int64_t> newHostDefinedUpperBound(hostDefinedUpperBound);
310-
std::swap(newHostDefinedUpperBound[0], newHostDefinedUpperBound[1]);
311-
narrowDim.dim = IREE::Encoding::MatmulNarrowDim::Dim::M;
312-
TileMxNxK tile =
313-
chooseMatmulTile(enumeratedTiles, narrowDim, newHostDefinedUpperBound);
314-
std::swap(tile.M, tile.N);
315-
return tile;
316-
}
317-
// Handle kDynamic: currently this is only used with VMVX, where there is only
318-
// one enumerated tile and it has all three M/N/K dimensions dynamic, so for
319-
// now we only support that. Generalize that as needed when more dynamic tile
320-
// sizes are used outside of VMVX, e.g. perhaps some day with Arm SVE. Decide
321-
// how to incorporate the handling of kDynamic in the cost-model evaluation
322-
// below to decide when to prefer a dynamic vs a static tile shape.
323-
for (auto tile : enumeratedTiles) {
324-
if (ShapedType::isDynamic(tile.M) || ShapedType::isDynamic(tile.N) ||
325-
ShapedType::isDynamic(tile.K)) {
326-
assert(enumeratedTiles.size() == 1);
327-
assert(ShapedType::isDynamic(tile.M) && ShapedType::isDynamic(tile.N) &&
328-
ShapedType::isDynamic(tile.K));
329-
return tile;
330-
}
331-
}
332-
// We're going to "rate" the enumerated tiles.
333-
struct RatedTileMxNxK : TileMxNxK {
334-
RatedTileMxNxK() {}
335-
RatedTileMxNxK(TileMxNxK tile) : TileMxNxK(tile) {}
336-
// Penalize tiles that are wider in the M dimension than matmulNarrowM.
337-
int64_t paddingPenalty = 0;
338-
// Favor larger tiles, as long as they still minimize paddingPenalty.
339-
int64_t productMxNxK = 0;
340-
};
341-
SmallVector<RatedTileMxNxK> ratedTiles;
342-
ratedTiles.reserve(enumeratedTiles.size());
343-
int64_t bestPaddingPenalty = INT64_MAX;
344-
int64_t mUB = INT64_MAX;
345-
int64_t nUB = INT64_MAX;
346-
int64_t kUB = INT64_MAX;
347-
if (!hostDefinedUpperBound.empty()) {
348-
mUB = hostDefinedUpperBound[0];
349-
nUB = hostDefinedUpperBound[1];
350-
kUB = hostDefinedUpperBound[2];
351-
}
352-
for (auto tile : enumeratedTiles) {
353-
if (tile.M > mUB || tile.N > nUB || tile.K > kUB) {
354-
LLVM_DEBUG(llvm::dbgs() << "[" << DEBUG_TYPE << "]: tile (";
355-
llvm::interleaveComma(
356-
ArrayRef<int64_t>{tile.M, tile.N, tile.K}, llvm::dbgs());
357-
llvm::dbgs()
358-
<< ") is skipped because it is not valid for upper_bound (";
359-
llvm::interleaveComma(ArrayRef<int64_t>{mUB, nUB, kUB},
360-
llvm::dbgs());
361-
llvm::dbgs() << ")\n");
362-
continue;
363-
}
364-
RatedTileMxNxK ratedTile(tile);
365-
ratedTile.paddingPenalty = 0;
366-
// If we are choosing a tile for a narrow-M case, we want to minimize
367-
// padding along the M dimension.
368-
// The PowerOf2Ceil is so that we are OK with padding up to the next
369-
// power of two, we just try to avoid padding beyond that. For example,
370-
// if matmulNarrowM==7 and we have enumerated tiles with M=8,4,2,1, we
371-
// are OK with the tile that has M==8 even though it requires some padding.
372-
// Otherwise, we would be penalizing the tiles with M==8,4,2 and we would
373-
// end up selecting the vecmat tile (M==1) for that case!
374-
if (narrowDim) {
375-
ratedTile.paddingPenalty =
376-
std::max<int64_t>(tile.M - llvm::PowerOf2Ceil(narrowDim.size), 0);
377-
}
378-
ratedTile.productMxNxK = tile.M * tile.N * tile.K;
379-
ratedTiles.push_back(ratedTile);
380-
381-
LLVM_DEBUG(llvm::dbgs() << "candidate: "; llvm::interleaveComma(
382-
ArrayRef<int64_t>{tile.M, tile.N, tile.K}, llvm::dbgs());
383-
llvm::dbgs() << " penalty:" << ratedTile.paddingPenalty << "\n");
384-
385-
bestPaddingPenalty = std::min(bestPaddingPenalty, ratedTile.paddingPenalty);
386-
}
387-
RatedTileMxNxK bestRatedTile;
388-
for (auto ratedTile : ratedTiles) {
389-
// Choose only among tiles that minimize paddingPenalty. Among those,
390-
// maximize productMxNxK.
391-
if (ratedTile.paddingPenalty == bestPaddingPenalty &&
392-
bestRatedTile.productMxNxK < ratedTile.productMxNxK) {
393-
bestRatedTile = ratedTile;
394-
}
395-
}
396-
// Sanity check. This assert can only fail if there's a programming mistake
397-
// locally here.
398-
assert(bestRatedTile.paddingPenalty == bestPaddingPenalty);
399-
return bestRatedTile;
400-
}
401-
402266
static SmallVector<TileMxNxK>
403267
enumerateMatmulTileMxNxK(IREE::Encoding::EncodingAttr encoding,
404268
IREE::HAL::ExecutableTargetAttr target) {
@@ -410,9 +274,6 @@ enumerateMatmulTileMxNxK(IREE::Encoding::EncodingAttr encoding,
410274
}
411275
// Enumerate available tile shapes for the given encoding and target.
412276
SmallVector<Type> elementTypes = encoding.getElementTypesArray();
413-
if (isVMVXBackend(target)) {
414-
return enumerateMatmulTilesVMVX(*cDims, encoding, target);
415-
}
416277
if (isAArch64(target)) {
417278
return enumerateMatmulTileArm64(elementTypes, target);
418279
}
@@ -442,8 +303,8 @@ materializeEncodingForTarget(RankedTensorType tensorType,
442303
auto narrowDim = IREE::Encoding::getMatmulNarrowDim(encoding);
443304
// Choose a final matmul TileMxNxK from the above-enumarated tile shapes,
444305
// taking narrow dimensions into account.
445-
TileMxNxK chosenTileMxNxK = chooseMatmulTile(enumeratedTileMxNxK, narrowDim,
446-
encoding.getRoundDimsToArray());
306+
TileMxNxK chosenTileMxNxK = IREE::Codegen::chooseMatmulTile(
307+
enumeratedTileMxNxK, narrowDim, encoding.getRoundDimsToArray());
447308

448309
// Map the matmul TileMxNxK to an actual tile shape for the tensor at hand,
449310
// based on its operand index in the matmul.
@@ -481,9 +342,15 @@ materializeFuncOpEncodings(FunctionOpInterface funcOp,
481342
// 2. We use ukernels, and this allows writing 2x fewer narrow ukernels.
482343
// 3. Heuristics for cache-friendly dispatch tiling can get complex on CPU,
483344
// so it is nice that they have fewer narrow cases to consider.
345+
IREE::Codegen::LayoutAttrInterface layoutAttr;
346+
if (isVMVXBackend(targetAttr)) {
347+
layoutAttr = cast<IREE::Codegen::LayoutAttrInterface>(
348+
IREE::CPU::VMVXEncodingLayoutAttr::get(ctx,
349+
targetAttr.getConfiguration()));
350+
}
484351
MaterializeEncodingTypeConverter typeConverter(
485352
materializeEncodingForTarget, targetAttr, /*transposeNarrowN=*/true,
486-
/*layoutAttr=*/{});
353+
layoutAttr);
487354
MaterializeEncodingConversionTarget target(*ctx);
488355
auto materializeEncodingValueFn = getMaterializeEncodingValueFn(targetAttr);
489356
populateMaterializeEncodingIntoPackUnPackPatterns(
@@ -547,8 +414,9 @@ struct CPUMaterializeHostEncodingPass
547414
: public impl::CPUMaterializeHostEncodingPassBase<
548415
CPUMaterializeHostEncodingPass> {
549416
void getDependentDialects(DialectRegistry &registry) const override {
550-
registry.insert<arith::ArithDialect, tensor::TensorDialect,
551-
IREE::Codegen::IREECodegenDialect>();
417+
registry
418+
.insert<arith::ArithDialect, tensor::TensorDialect,
419+
IREE::Codegen::IREECodegenDialect, IREE::CPU::IREECPUDialect>();
552420
}
553421

554422
void runOnOperation() override {
@@ -607,8 +475,9 @@ struct CPUMaterializeDeviceEncodingPass
607475
: public impl::CPUMaterializeDeviceEncodingPassBase<
608476
CPUMaterializeDeviceEncodingPass> {
609477
void getDependentDialects(DialectRegistry &registry) const override {
610-
registry.insert<arith::ArithDialect, tensor::TensorDialect,
611-
IREE::Codegen::IREECodegenDialect>();
478+
registry
479+
.insert<arith::ArithDialect, tensor::TensorDialect,
480+
IREE::Codegen::IREECodegenDialect, IREE::CPU::IREECPUDialect>();
612481
}
613482

614483
void runOnOperation() override {
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright 2024 The IREE Authors
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
package(
8+
default_visibility = ["//visibility:public"],
9+
features = ["layering_check"],
10+
licenses = ["notice"], # Apache 2.0
11+
)

0 commit comments

Comments
 (0)