Skip to content

Commit 8151f1e

Browse files
openeuler-ci-bot, gitee-org
authored and committed
!200 [SME][matrix_type] lower matrix_type with ARM SME/SVE instructions
From: @chenzheng1030 Reviewed-by: @eastB233 Signed-off-by: @eastB233
2 parents 83fb1f3 + 15de98a commit 8151f1e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+2720
-17
lines changed

clang/include/clang/AST/Type.h

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4176,6 +4176,21 @@ class MatrixType : public Type, public llvm::FoldingSetNode {
41764176
(T->isRealType() && !T->isBooleanType() && !T->isEnumeralType());
41774177
}
41784178

4179+
/// Returns true if \p T is a matrix element type that this change lowers
/// through the SME-specific matrix code paths.
///
/// The checks are applied in order:
///  1. \p T must already be accepted by isValidElementType().
///  2. \p T must be a BuiltinType.
///  3. \p T must not be __bf16.
///  4. The builtin type must report isFloatingPoint() or isInteger().
///
/// Elsewhere in this commit the CodeGen call sites pair this predicate with a
/// target "sme" feature check and fall back to the generic MatrixBuilder
/// lowering when either test fails.
///
/// NOTE(review): isa<BuiltinType>(T) is applied to the QualType as given. It
/// presumably does not look through sugar (e.g. a typedef'd element type), so
/// such types would be rejected here - confirm that callers always pass a
/// canonical element type, or whether that fallback is intended.
static bool isValidTypeForSME(QualType T) {
4180+
// Anything that is not a legal matrix element type is rejected outright.
if (!isValidElementType(T))
4181+
return false;
4182+
4183+
// The casts to BuiltinType below rely on this check having passed.
if (!isa<BuiltinType>(T))
4184+
return false;
4185+
4186+
// AArch64 cannot do vector operations like fma/add/sub for __bf16.
4187+
if (T->isBFloat16Type())
4188+
return false;
4189+
4190+
// Accept the remaining builtin floating-point and integer types.
return cast<BuiltinType>(T)->isFloatingPoint() ||
4191+
cast<BuiltinType>(T)->isInteger();
4192+
}
4193+
41794194
bool isSugared() const { return false; }
41804195
QualType desugar() const { return QualType(this, 0); }
41814196

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3976,12 +3976,18 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
39763976
*this, E, llvm::Intrinsic::vector_reduce_and, "rdx.and"));
39773977

39783978
case Builtin::BI__builtin_matrix_transpose: {
3979-
auto *MatrixTy = E->getArg(0)->getType()->castAs<ConstantMatrixType>();
3980-
Value *MatValue = EmitScalarExpr(E->getArg(0));
3979+
auto *MatrixValue = E->getArg(0);
3980+
auto *MatrixTy = MatrixValue->getType()->castAs<ConstantMatrixType>();
3981+
Value *MatValue = EmitScalarExpr(MatrixValue);
39813982
MatrixBuilder MB(Builder);
3982-
Value *Result = MB.CreateMatrixTranspose(MatValue, MatrixTy->getNumRows(),
3983-
MatrixTy->getNumColumns());
3984-
return RValue::get(Result);
3983+
3984+
if (!getContext().getTargetInfo().hasFeature("sme") ||
3985+
!MatrixType::isValidTypeForSME(MatrixTy->getElementType()))
3986+
return RValue::get(MB.CreateMatrixTranspose(
3987+
MatValue, MatrixTy->getNumRows(), MatrixTy->getNumColumns()));
3988+
3989+
return RValue::get(MB.CreateSMEMatrixTranspose(
3990+
MatValue,MatrixTy->getNumRows(),MatrixTy->getNumColumns()));
39853991
}
39863992

39873993
case Builtin::BI__builtin_matrix_column_major_load: {

clang/lib/CodeGen/CGExprScalar.cpp

Lines changed: 37 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -764,10 +764,23 @@ class ScalarExprEmitter
764764
auto *RHSMatTy = dyn_cast<ConstantMatrixType>(
765765
BO->getRHS()->getType().getCanonicalType());
766766
CodeGenFunction::CGFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);
767-
if (LHSMatTy && RHSMatTy)
768-
return MB.CreateMatrixMultiply(Ops.LHS, Ops.RHS, LHSMatTy->getNumRows(),
769-
LHSMatTy->getNumColumns(),
770-
RHSMatTy->getNumColumns());
767+
if (LHSMatTy && RHSMatTy) {
768+
// Note that SME only has non-widening MOPA for float32 and float64, so
769+
// only these two types have native SME matmul operations. For other
770+
// types, SVE version is used. We hope that SVE version is better than
771+
// default NEON or scalar version.
772+
auto Ty = LHSMatTy->getElementType();
773+
if (!CGF.getContext().getTargetInfo().hasFeature("sme") ||
774+
!MatrixType::isValidTypeForSME(Ty))
775+
return MB.CreateMatrixMultiply(
776+
Ops.LHS, Ops.RHS, LHSMatTy->getNumRows(),
777+
LHSMatTy->getNumColumns(), RHSMatTy->getNumColumns());
778+
assert(isa<BuiltinType>(Ty) && "SME types should be BuiltinType.");
779+
return MB.CreateSMEMatrixMultiply(
780+
Ops.LHS, Ops.RHS, LHSMatTy->getNumRows(), LHSMatTy->getNumColumns(),
781+
RHSMatTy->getNumColumns(),
782+
cast<BuiltinType>(Ty)->isSignedInteger());
783+
}
771784
return MB.CreateScalarMultiply(Ops.LHS, Ops.RHS);
772785
}
773786

@@ -4170,7 +4183,16 @@ Value *ScalarExprEmitter::EmitAdd(const BinOpInfo &op) {
41704183
if (op.Ty->isConstantMatrixType()) {
41714184
llvm::MatrixBuilder MB(Builder);
41724185
CodeGenFunction::CGFPOptionsRAII FPOptsRAII(CGF, op.FPFeatures);
4173-
return MB.CreateAdd(op.LHS, op.RHS);
4186+
4187+
auto *MatTy = cast<ConstantMatrixType>(op.E->getType().getCanonicalType());
4188+
auto Ty = MatTy->getElementType();
4189+
if (!CGF.getContext().getTargetInfo().hasFeature("sme") ||
4190+
!MatrixType::isValidTypeForSME(Ty))
4191+
return MB.CreateAdd(op.LHS, op.RHS);
4192+
assert(isa<BuiltinType>(Ty) && "SME types should be BuiltinType.");
4193+
return MB.CreateSMEMatrixBinOp(
4194+
op.LHS, op.RHS, MatTy->getNumRows(), MatTy->getNumColumns(),
4195+
cast<BuiltinType>(Ty)->isSignedInteger(), "add");
41744196
}
41754197

41764198
if (op.Ty->isUnsignedIntegerType() &&
@@ -4326,7 +4348,16 @@ Value *ScalarExprEmitter::EmitSub(const BinOpInfo &op) {
43264348
if (op.Ty->isConstantMatrixType()) {
43274349
llvm::MatrixBuilder MB(Builder);
43284350
CodeGenFunction::CGFPOptionsRAII FPOptsRAII(CGF, op.FPFeatures);
4329-
return MB.CreateSub(op.LHS, op.RHS);
4351+
auto *MatTy =
4352+
cast<ConstantMatrixType>(op.E->getType().getCanonicalType());
4353+
auto Ty = MatTy->getElementType();
4354+
if (!CGF.getContext().getTargetInfo().hasFeature("sme") ||
4355+
!MatrixType::isValidTypeForSME(Ty))
4356+
return MB.CreateSub(op.LHS, op.RHS);
4357+
assert(isa<BuiltinType>(Ty) && "SME types should be BuiltinType.");
4358+
return MB.CreateSMEMatrixBinOp(
4359+
op.LHS, op.RHS, MatTy->getNumRows(), MatTy->getNumColumns(),
4360+
cast<BuiltinType>(Ty)->isSignedInteger(), "sub");
43304361
}
43314362

43324363
if (op.Ty->isUnsignedIntegerType() &&

clang/lib/CodeGen/Targets/AArch64.cpp

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -294,6 +294,10 @@ AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadic,
294294
if (isIllegalVectorType(Ty))
295295
return coerceIllegalVector(Ty);
296296

297+
// Always pass the matrix type via memory.
298+
if (Ty->isMatrixType())
299+
return getNaturalAlignIndirect(Ty, false);
300+
297301
if (!isAggregateTypeForABI(Ty)) {
298302
// Treat an enum type as its underlying type.
299303
if (const EnumType *EnumTy = Ty->getAs<EnumType>())
@@ -393,6 +397,10 @@ ABIArgInfo AArch64ABIInfo::classifyReturnType(QualType RetTy,
393397
return coerceIllegalVector(RetTy);
394398
}
395399

400+
// Always return the matrix type via memory.
401+
if (RetTy->isMatrixType())
402+
return getNaturalAlignIndirect(RetTy);
403+
396404
// Large vector types should be returned via memory.
397405
if (RetTy->isVectorType() && getContext().getTypeSize(RetTy) > 128)
398406
return getNaturalAlignIndirect(RetTy);

clang/lib/Driver/ToolChain.cpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1112,6 +1112,10 @@ ToolChain::RuntimeLibType ToolChain::GetRuntimeLibType(
11121112
runtimeLibType = GetDefaultRuntimeLibType();
11131113
}
11141114

1115+
const llvm::Triple::ArchType Arch = getArch();
1116+
if (Arch == llvm::Triple::aarch64 && Args.hasArg(options::OPT_fenable_matrix))
1117+
runtimeLibType = ToolChain::RLT_CompilerRT;
1118+
11151119
return *runtimeLibType;
11161120
}
11171121

0 commit comments

Comments
 (0)