Skip to content

Commit 42b36c5

Browse files
FP8 bitcast
1 parent e6da3ea commit 42b36c5

File tree

6 files changed

+372
-0
lines changed

6 files changed

+372
-0
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3415,6 +3415,8 @@ def err_typecheck_vector_not_convertable : Error<
34153415
"cannot convert between vector values of different size (%0 and %1)">;
34163416
def err_typecheck_vector_not_convertable_non_scalar : Error<
34173417
"cannot convert between vector and non-scalar values (%0 and %1)">;
3418+
def err_typecheck_vector_not_convertable_non_vector : Error<
3419+
"cannot convert between vector type %0 and non-vector type %1">;
34183420
def err_typecheck_vector_lengths_not_equal : Error<
34193421
"vector operands do not have the same number of elements (%0 and %1)">;
34203422
def warn_typecheck_vector_element_sizes_not_equal : Warning<

clang/include/clang/Sema/Sema.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7429,6 +7429,11 @@ class Sema final : public SemaBase {
74297429
/// the perspective of SVE bitcasts.
74307430
bool isValidSveBitcast(QualType srcType, QualType destType);
74317431

7432+
/// Check for bitcast beween a regular vector type and builtin Neon vector
7433+
/// type.
7434+
bool isValidNeonVectorBuiltinTypeBitcast(SourceRange OpRange, QualType SrcTy,
7435+
QualType DstTy);
7436+
74327437
/// Are the two types matrix types and do they have the same dimensions i.e.
74337438
/// do they have the same number of rows and the same number of columns?
74347439
bool areMatrixTypesOfTheSameDimension(QualType srcTy, QualType destTy);

clang/lib/Sema/SemaCast.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2386,6 +2386,16 @@ static TryCastResult TryReinterpretCast(Sema &Self, ExprResult &SrcExpr,
23862386
return TC_Success;
23872387
}
23882388

2389+
// Allow bitcasting between a regular vector type or a scalar, and a builtin
2390+
// Neon vector type.
2391+
if ((SrcType->isNeonVectorBuiltinType() ||
2392+
DestType->isNeonVectorBuiltinType())) {
2393+
if (!Self.isValidNeonVectorBuiltinTypeBitcast(OpRange, SrcType, DestType))
2394+
return TC_Failed;
2395+
Kind = CK_BitCast;
2396+
return TC_Success;
2397+
}
2398+
23892399
// Allow reinterpret_casts between vectors of the same size and
23902400
// between vectors and integers of the same size.
23912401
bool destIsVector = DestType->isVectorType();
@@ -3009,6 +3019,18 @@ void CastOperation::CheckCStyleCast() {
30093019
return;
30103020
}
30113021

3022+
// Allow bitcasting between a regular vector type and a builtin Neon vector
3023+
// type.
3024+
if (SrcType->isNeonVectorBuiltinType() ||
3025+
DestType->isNeonVectorBuiltinType()) {
3026+
if (!Self.isValidNeonVectorBuiltinTypeBitcast(OpRange, SrcType, DestType)) {
3027+
SrcExpr = ExprError();
3028+
return;
3029+
}
3030+
Kind = CK_BitCast;
3031+
return;
3032+
}
3033+
30123034
// Allow bitcasting between compatible RVV vector types.
30133035
if ((SrcType->isVectorType() || DestType->isVectorType()) &&
30143036
Self.RISCV().isValidRVVBitcast(SrcType, DestType)) {

clang/lib/Sema/SemaExpr.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7530,6 +7530,54 @@ bool Sema::isValidSveBitcast(QualType srcTy, QualType destTy) {
75307530
ValidScalableConversion(destTy, srcTy);
75317531
}
75327532

7533+
bool Sema::isValidNeonVectorBuiltinTypeBitcast(SourceRange OpRange,
7534+
QualType SrcTy, QualType DstTy) {
7535+
assert(SrcTy->isNeonVectorBuiltinType() || DstTy->isNeonVectorBuiltinType());
7536+
7537+
auto checkCast = [&](QualType BT, QualType OT) -> unsigned {
7538+
if (OT->isNeonVectorBuiltinType()) {
7539+
if (BT.getCanonicalType() != OT.getCanonicalType())
7540+
return diag::err_invalid_conversion_between_vectors;
7541+
return 0;
7542+
}
7543+
7544+
if (!OT->isVectorType()) {
7545+
if (!OT->isScalarType())
7546+
return diag::err_typecheck_vector_not_convertable_non_scalar;
7547+
7548+
if (!OT->isIntegralType(Context))
7549+
return diag::err_typecheck_vector_not_convertable_non_vector;
7550+
}
7551+
7552+
uint64_t OTLen;
7553+
QualType OTEltTy;
7554+
if (const VectorType *VT = OT->getAs<VectorType>()) {
7555+
OTLen = VT->getNumElements();
7556+
OTEltTy = VT->getElementType();
7557+
} else {
7558+
OTLen = 1;
7559+
OTEltTy = OT;
7560+
}
7561+
7562+
uint64_t BTLen =
7563+
BT->getAs<BuiltinType>()->getKind() == BuiltinType::MFloat8x8 ? 8u
7564+
: 16u;
7565+
if (BTLen * 8u != OTLen * Context.getTypeSize(OTEltTy))
7566+
return diag::err_invalid_conversion_between_vectors;
7567+
return 0;
7568+
};
7569+
7570+
QualType BT = SrcTy;
7571+
QualType OT = DstTy;
7572+
if (!SrcTy->isNeonVectorBuiltinType())
7573+
std::swap(BT, OT);
7574+
7575+
if (unsigned msg = checkCast(BT, OT))
7576+
return Diag(OpRange.getBegin(), msg) << BT << OT << OpRange;
7577+
7578+
return true;
7579+
}
7580+
75337581
bool Sema::areMatrixTypesOfTheSameDimension(QualType srcTy, QualType destTy) {
75347582
if (!destTy->isMatrixType() || !srcTy->isMatrixType())
75357583
return false;
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
2+
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +neon -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg | FileCheck %s
3+
// RUN: %clang_cc1 -x c++ -triple aarch64-none-linux-gnu -target-feature +neon -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -p mem2reg | FileCheck %s -check-prefix CHECK-CXX
4+
5+
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +neon -disable-O0-optnone -Werror -Wall -S -o /dev/null %s
6+
7+
// REQUIRES: aarch64-registered-target
8+
9+
#include <arm_neon.h>
10+
11+
// Bitcast between FP8 Neon vectors
12+
// CHECK-LABEL: define dso_local <8 x i8> @test_f8_f8(
13+
// CHECK-SAME: <8 x i8> [[X:%.*]]) #[[ATTR0:[0-9]+]] {
14+
// CHECK-NEXT: [[ENTRY:.*:]]
15+
// CHECK-NEXT: ret <8 x i8> [[X]]
16+
//
17+
// CHECK-CXX-LABEL: define dso_local <8 x i8> @_Z10test_f8_f8u13__MFloat8x8_t(
18+
// CHECK-CXX-SAME: <8 x i8> [[X:%.*]]) #[[ATTR0:[0-9]+]] {
19+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
20+
// CHECK-CXX-NEXT: ret <8 x i8> [[X]]
21+
//
22+
mfloat8x8_t test_f8_f8(mfloat8x8_t x) {
23+
return (mfloat8x8_t) x;
24+
}
25+
26+
// CHECK-LABEL: define dso_local <16 x i8> @testq_f8_f8(
27+
// CHECK-SAME: <16 x i8> [[X:%.*]]) #[[ATTR0]] {
28+
// CHECK-NEXT: [[ENTRY:.*:]]
29+
// CHECK-NEXT: ret <16 x i8> [[X]]
30+
//
31+
// CHECK-CXX-LABEL: define dso_local <16 x i8> @_Z11testq_f8_f8u14__MFloat8x16_t(
32+
// CHECK-CXX-SAME: <16 x i8> [[X:%.*]]) #[[ATTR0]] {
33+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
34+
// CHECK-CXX-NEXT: ret <16 x i8> [[X]]
35+
//
36+
mfloat8x16_t testq_f8_f8(mfloat8x16_t x) {
37+
return (mfloat8x16_t) x;
38+
}
39+
40+
// Bitcast between FP8 and int8 Neon vectors
41+
// CHECK-LABEL: define dso_local <8 x i8> @test_f8_s8(
42+
// CHECK-SAME: <8 x i8> noundef [[X:%.*]]) #[[ATTR0]] {
43+
// CHECK-NEXT: [[ENTRY:.*:]]
44+
// CHECK-NEXT: ret <8 x i8> [[X]]
45+
//
46+
// CHECK-CXX-LABEL: define dso_local <8 x i8> @_Z10test_f8_s810__Int8x8_t(
47+
// CHECK-CXX-SAME: <8 x i8> noundef [[X:%.*]]) #[[ATTR0]] {
48+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
49+
// CHECK-CXX-NEXT: ret <8 x i8> [[X]]
50+
//
51+
mfloat8x8_t test_f8_s8(int8x8_t x) {
52+
return (mfloat8x8_t) x;
53+
}
54+
55+
// CHECK-LABEL: define dso_local <8 x i8> @test_s8_f8(
56+
// CHECK-SAME: <8 x i8> [[X:%.*]]) #[[ATTR0]] {
57+
// CHECK-NEXT: [[ENTRY:.*:]]
58+
// CHECK-NEXT: ret <8 x i8> [[X]]
59+
//
60+
// CHECK-CXX-LABEL: define dso_local noundef <8 x i8> @_Z10test_s8_f8u13__MFloat8x8_t(
61+
// CHECK-CXX-SAME: <8 x i8> [[X:%.*]]) #[[ATTR0]] {
62+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
63+
// CHECK-CXX-NEXT: ret <8 x i8> [[X]]
64+
//
65+
int8x8_t test_s8_f8(mfloat8x8_t x) {
66+
return (int8x8_t) x;
67+
}
68+
69+
// CHECK-LABEL: define dso_local <16 x i8> @testq_f8_s8(
70+
// CHECK-SAME: <16 x i8> noundef [[X:%.*]]) #[[ATTR0]] {
71+
// CHECK-NEXT: [[ENTRY:.*:]]
72+
// CHECK-NEXT: ret <16 x i8> [[X]]
73+
//
74+
// CHECK-CXX-LABEL: define dso_local <16 x i8> @_Z11testq_f8_s811__Int8x16_t(
75+
// CHECK-CXX-SAME: <16 x i8> noundef [[X:%.*]]) #[[ATTR0]] {
76+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
77+
// CHECK-CXX-NEXT: ret <16 x i8> [[X]]
78+
//
79+
mfloat8x16_t testq_f8_s8(int8x16_t x) {
80+
return (mfloat8x16_t) x;
81+
}
82+
83+
// CHECK-LABEL: define dso_local <16 x i8> @testq_s8_f8(
84+
// CHECK-SAME: <16 x i8> [[X:%.*]]) #[[ATTR0]] {
85+
// CHECK-NEXT: [[ENTRY:.*:]]
86+
// CHECK-NEXT: ret <16 x i8> [[X]]
87+
//
88+
// CHECK-CXX-LABEL: define dso_local noundef <16 x i8> @_Z11testq_s8_f8u14__MFloat8x16_t(
89+
// CHECK-CXX-SAME: <16 x i8> [[X:%.*]]) #[[ATTR0]] {
90+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
91+
// CHECK-CXX-NEXT: ret <16 x i8> [[X]]
92+
//
93+
int8x16_t testq_s8_f8(mfloat8x16_t x) {
94+
return (int8x16_t) x;
95+
}
96+
97+
// Bitcast between FP8 and float32 Neon vectors
98+
// CHECK-LABEL: define dso_local <8 x i8> @test_f8_f32(
99+
// CHECK-SAME: <2 x float> noundef [[X:%.*]]) #[[ATTR0]] {
100+
// CHECK-NEXT: [[ENTRY:.*:]]
101+
// CHECK-NEXT: [[TMP0:%.*]] = bitcast <2 x float> [[X]] to <8 x i8>
102+
// CHECK-NEXT: ret <8 x i8> [[TMP0]]
103+
//
104+
// CHECK-CXX-LABEL: define dso_local <8 x i8> @_Z11test_f8_f3213__Float32x2_t(
105+
// CHECK-CXX-SAME: <2 x float> noundef [[X:%.*]]) #[[ATTR0]] {
106+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
107+
// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <2 x float> [[X]] to <8 x i8>
108+
// CHECK-CXX-NEXT: ret <8 x i8> [[TMP0]]
109+
//
110+
mfloat8x8_t test_f8_f32(float32x2_t x) {
111+
return (mfloat8x8_t) x;
112+
}
113+
114+
// CHECK-LABEL: define dso_local <2 x float> @test_f32_f8(
115+
// CHECK-SAME: <8 x i8> [[X:%.*]]) #[[ATTR0]] {
116+
// CHECK-NEXT: [[ENTRY:.*:]]
117+
// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x i8> [[X]] to <2 x float>
118+
// CHECK-NEXT: ret <2 x float> [[TMP0]]
119+
//
120+
// CHECK-CXX-LABEL: define dso_local noundef <2 x float> @_Z11test_f32_f8u13__MFloat8x8_t(
121+
// CHECK-CXX-SAME: <8 x i8> [[X:%.*]]) #[[ATTR0]] {
122+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
123+
// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <8 x i8> [[X]] to <2 x float>
124+
// CHECK-CXX-NEXT: ret <2 x float> [[TMP0]]
125+
//
126+
float32x2_t test_f32_f8(mfloat8x8_t x) {
127+
return (float32x2_t) x;
128+
}
129+
130+
// CHECK-LABEL: define dso_local <16 x i8> @testq_f8_f32(
131+
// CHECK-SAME: <4 x float> noundef [[X:%.*]]) #[[ATTR0]] {
132+
// CHECK-NEXT: [[ENTRY:.*:]]
133+
// CHECK-NEXT: [[TMP0:%.*]] = bitcast <4 x float> [[X]] to <16 x i8>
134+
// CHECK-NEXT: ret <16 x i8> [[TMP0]]
135+
//
136+
// CHECK-CXX-LABEL: define dso_local <16 x i8> @_Z12testq_f8_f3213__Float32x4_t(
137+
// CHECK-CXX-SAME: <4 x float> noundef [[X:%.*]]) #[[ATTR0]] {
138+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
139+
// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <4 x float> [[X]] to <16 x i8>
140+
// CHECK-CXX-NEXT: ret <16 x i8> [[TMP0]]
141+
//
142+
mfloat8x16_t testq_f8_f32(float32x4_t x) {
143+
return (mfloat8x16_t) x;
144+
}
145+
146+
// CHECK-LABEL: define dso_local <4 x float> @testq_f32_f8(
147+
// CHECK-SAME: <16 x i8> [[X:%.*]]) #[[ATTR0]] {
148+
// CHECK-NEXT: [[ENTRY:.*:]]
149+
// CHECK-NEXT: [[TMP0:%.*]] = bitcast <16 x i8> [[X]] to <4 x float>
150+
// CHECK-NEXT: ret <4 x float> [[TMP0]]
151+
//
152+
// CHECK-CXX-LABEL: define dso_local noundef <4 x float> @_Z12testq_f32_f8u14__MFloat8x16_t(
153+
// CHECK-CXX-SAME: <16 x i8> [[X:%.*]]) #[[ATTR0]] {
154+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
155+
// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <16 x i8> [[X]] to <4 x float>
156+
// CHECK-CXX-NEXT: ret <4 x float> [[TMP0]]
157+
//
158+
float32x4_t testq_f32_f8(mfloat8x16_t x) {
159+
return (float32x4_t) x;
160+
}
161+
162+
// Bitcast between FP8 and poly128_t (which is integral)
163+
// CHECK-LABEL: define dso_local <16 x i8> @testq_f8_p128(
164+
// CHECK-SAME: i128 noundef [[X:%.*]]) #[[ATTR0]] {
165+
// CHECK-NEXT: [[ENTRY:.*:]]
166+
// CHECK-NEXT: [[TMP0:%.*]] = bitcast i128 [[X]] to <16 x i8>
167+
// CHECK-NEXT: ret <16 x i8> [[TMP0]]
168+
//
169+
// CHECK-CXX-LABEL: define dso_local <16 x i8> @_Z13testq_f8_p128o(
170+
// CHECK-CXX-SAME: i128 noundef [[X:%.*]]) #[[ATTR0]] {
171+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
172+
// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast i128 [[X]] to <16 x i8>
173+
// CHECK-CXX-NEXT: ret <16 x i8> [[TMP0]]
174+
//
175+
mfloat8x16_t testq_f8_p128(poly128_t x) {
176+
return (mfloat8x16_t) x;
177+
}
178+
179+
// CHECK-LABEL: define dso_local i128 @testq_p128_f8(
180+
// CHECK-SAME: <16 x i8> [[X:%.*]]) #[[ATTR0]] {
181+
// CHECK-NEXT: [[ENTRY:.*:]]
182+
// CHECK-NEXT: [[TMP0:%.*]] = bitcast <16 x i8> [[X]] to i128
183+
// CHECK-NEXT: ret i128 [[TMP0]]
184+
//
185+
// CHECK-CXX-LABEL: define dso_local noundef i128 @_Z13testq_p128_f8u14__MFloat8x16_t(
186+
// CHECK-CXX-SAME: <16 x i8> [[X:%.*]]) #[[ATTR0]] {
187+
// CHECK-CXX-NEXT: [[ENTRY:.*:]]
188+
// CHECK-CXX-NEXT: [[TMP0:%.*]] = bitcast <16 x i8> [[X]] to i128
189+
// CHECK-CXX-NEXT: ret i128 [[TMP0]]
190+
//
191+
poly128_t testq_p128_f8(mfloat8x16_t x) {
192+
return (poly128_t) x;
193+
}

0 commit comments

Comments
 (0)