
Commit fc178c2

Support lowering of mixed precision vector.contract to AMX. (#1070)
This patch supports mixed-precision lowering of vector.contract with bf16/i8 inputs and fp32/i32 output types, respectively, while also refactoring the code for readability and modularity. It also prepares the ground for other possible mixed-precision operations. In addition, the 'REQUIRES' directive was observed not to work due to a lit config gap; this patch fixes that issue as well.
1 parent 68ad15b commit fc178c2
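
For context, here is a reference-only sketch (plain C++, not code from this patch) of the scalar semantics being enabled: inputs stay in the narrow type and accumulation happens in the wider type, i.e. bf16 x bf16 accumulated in f32 and i8 x i8 accumulated in i32. The helper names below are made up for illustration.

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// bf16 is the upper 16 bits of an IEEE-754 f32, so widening is a bit shift.
static float bf16ToF32(uint16_t bits) {
  uint32_t wide = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &wide, sizeof(f));
  return f;
}

// bf16 dot product: widen both operands to f32, multiply, accumulate in f32.
// Assumes a.size() == b.size().
static float dotBf16(const std::vector<uint16_t> &a,
                     const std::vector<uint16_t> &b) {
  float acc = 0.0f;
  for (std::size_t k = 0; k < a.size(); ++k)
    acc += bf16ToF32(a[k]) * bf16ToF32(b[k]);
  return acc;
}

// i8 dot product: sign-extend both operands to i32, multiply, accumulate in i32.
static int32_t dotI8(const std::vector<int8_t> &a,
                     const std::vector<int8_t> &b) {
  int32_t acc = 0;
  for (std::size_t k = 0; k < a.size(); ++k)
    acc += static_cast<int32_t>(a[k]) * static_cast<int32_t>(b[k]);
  return acc;
}

This widen-multiply-accumulate shape is what the AMX tile instructions (e.g. TDPBF16PS for bf16, TDPBSSD for signed i8) compute on 2D tiles, which is why contracts of this form can be lowered to AMX directly.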

File tree

6 files changed: +331 −154 lines

lib/TPP/PassBundles/VectorToKernel.cpp

Lines changed: 2 additions & 7 deletions
@@ -52,17 +52,12 @@ struct VectorToKernel : public tpp::impl::VectorToKernelBase<VectorToKernel>,
 
 private:
   void constructPipeline() override {
-    // TODO: Pass ordering based on target architecture starting from AMX ->
-    // avx512 -> avx2 to subset needs to be improved by updating the `k`
-    // tile size check for AMX lowering. With k = 1 (or vnni size) AMX fails
-    // lowering to micro-kernels on EMR. Bf16DotProduct tests with k = 1
-    // and those tests gets lowered by AMX pass on EMR machine.
     pm.addNestedPass<func::FuncOp>(createHoistVectorTransfers());
+    if (vnni::utils::hasAMX())
+      pm.addNestedPass<func::FuncOp>(createVectorContractToAMX());
     MicroKernelsOptions options;
     options.targetFeature = vecBundleCpuTargetFeature;
     pm.addNestedPass<func::FuncOp>(createMicroKernels(options));
-    if (vnni::utils::hasAMX())
-      pm.addNestedPass<func::FuncOp>(createVectorContractToAMX());
     pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   }
 };

lib/TPP/Transforms/TransformUtils.cpp

Lines changed: 7 additions & 3 deletions
@@ -268,9 +268,13 @@ isContraction(linalg::LinalgOp linalgOp)
           .operation(NumDpsInits(EqualsTo(1)))
           .operation(NumDpsInputs(EqualsTo(2)))
           .operation(NumAffineMaps(EqualsTo(3)))
-          .region(MatchOne(0),
-                  WithOpChain<arith::MulFOp,
-                              arith::AddFOp>(/*captures=*/nullptr));
+          .region(MatchOne(0), [&](Region *region, Operation *op) {
+            return WithOpChain<KindMul, KindAdd>(/*captures=*/nullptr)(region, op) ||
+                   WithOpChain<arith::ExtFOp,
+                               arith::ExtFOp, KindMul, KindAdd>(nullptr)(region, op) ||
+                   WithOpChain<arith::ExtSIOp,
+                               arith::ExtSIOp, KindMul, KindAdd>(nullptr)(region, op);
+          });
   // clang-format on
   if (!maybeContraction.match(linalgOp))
     return failure();
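
The lambda above ORs together three accepted op chains: the original same-precision multiply-add, plus two widened variants where both inputs first go through arith::ExtFOp (bf16) or arith::ExtSIOp (i8); KindMul/KindAdd presumably cover both the float and integer multiply/add ops. As a rough illustration (hypothetical names, plain C++, not code from the patch), the matched element-wise bodies correspond to:

#include <cstdint>

// 1) mul -> add on matching element types (the only chain matched before).
static float bodyF32(float a, float b, float acc) { return acc + a * b; }

// 2) extf -> extf -> mul -> add: both bf16 inputs widened to f32 first.
//    `widenBf16` stands in for arith::ExtFOp; its definition is elided here.
static float bodyBf16(uint16_t a, uint16_t b, float acc,
                      float (*widenBf16)(uint16_t)) {
  return acc + widenBf16(a) * widenBf16(b);
}

// 3) extsi -> extsi -> mul -> add: both i8 inputs sign-extended to i32 first.
static int32_t bodyI8(int8_t a, int8_t b, int32_t acc) {
  return acc + static_cast<int32_t>(a) * static_cast<int32_t>(b);
}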

lib/TPP/Transforms/Utils/VNNIUtils.cpp

Lines changed: 7 additions & 3 deletions
@@ -57,14 +57,17 @@ unsigned getVnniBlockingFactor(Type type, Operation *op) {
   unsigned blockingFactor = 0;
 
   auto elementType = getElementTypeOrSelf(type);
-  if (elementType.isBF16()) {
+  if (elementType.isBF16() || elementType.isInteger(8)) {
     // Check if a VNNI factor hint is associated to the IR via DLTI.
     auto vnniValue = dlti::utils::query(op, {"CPU", "vnni"});
     if (succeeded(vnniValue)) {
       if (auto intAttr = llvm::dyn_cast<IntegerAttr>(*vnniValue))
         blockingFactor = intAttr.getInt();
     } else {
-      blockingFactor = libxsmm_cpuid_dot_pack_factor(LIBXSMM_DATATYPE_BF16);
+      blockingFactor =
+          elementType.isBF16()
+              ? libxsmm_cpuid_dot_pack_factor(LIBXSMM_DATATYPE_BF16)
+              : libxsmm_cpuid_dot_pack_factor(LIBXSMM_DATATYPE_I8);
     }
   }
 

@@ -177,7 +180,8 @@ bool isInVnniLayout(VnniOperandRank expectedRank, ShapedType shape,
 
 bool isInVnniLayout(int64_t expectedRank, ShapedType shape,
                     std::optional<unsigned> blockingFactor) {
-  if (shape.getRank() != expectedRank || !shape.getElementType().isBF16())
+  if (shape.getRank() != expectedRank ||
+      !(shape.getElementType().isBF16() || shape.getElementType().isInteger(8)))
     return false;
 
   auto vnniDim = shape.getShape().back();
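
For context on the VNNI checks above: a VNNI-packed operand keeps the reduction dimension blocked by the dot-product pack factor, which libxsmm typically reports as 2 for bf16 and 4 for i8 on VNNI/AMX-capable CPUs. The sketch below is an illustration only (not part of the patch, and it assumes K is divisible by the factor) of the [K/vnni, N, vnni] packing whose innermost dimension the layout check inspects via shape.getShape().back().

#include <cstdint>
#include <vector>

// Pack a row-major K x N operand into [K / vnni][N][vnni] so that `vnni`
// consecutive elements of the reduction dimension become contiguous.
template <typename T>
std::vector<T> packVnni(const std::vector<T> &mat, int64_t K, int64_t N,
                        int64_t vnni) {
  std::vector<T> packed(mat.size()); // mat.size() == K * N by assumption
  for (int64_t k = 0; k < K; ++k)
    for (int64_t n = 0; n < N; ++n)
      // Element (k, n) lands in block k / vnni, column n, slot k % vnni.
      packed[((k / vnni) * N + n) * vnni + (k % vnni)] = mat[k * N + n];
  return packed;
}

Packed this way, a bf16 operand has a trailing dimension of 2 and an i8 operand a trailing dimension of 4; the updated checks now admit both element types.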
