Skip to content

Commit 0b76033

Browse files
authored
[Backport to 17] CooperativeMatrix translation of load/store and arithmetic instructions (#3624)
Backport of PRs: #2117 #2156 #2165 #2166 to llvm_release_170 authored by @vmaksimo
1 parent f628677 commit 0b76033

File tree

4 files changed

+280
-4
lines changed

4 files changed

+280
-4
lines changed

lib/SPIRV/SPIRVReader.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,6 +1276,9 @@ static void applyFPFastMathModeDecorations(const SPIRVValue *BV,
12761276
Value *SPIRVToLLVM::transShiftLogicalBitwiseInst(SPIRVValue *BV, BasicBlock *BB,
12771277
Function *F) {
12781278
SPIRVBinary *BBN = static_cast<SPIRVBinary *>(BV);
1279+
if (BV->getType()->isTypeCooperativeMatrixKHR()) {
1280+
return mapValue(BV, transSPIRVBuiltinFromInst(BBN, BB));
1281+
}
12791282
Instruction::BinaryOps BO;
12801283
auto OP = BBN->getOpCode();
12811284
if (isLogicalOpCode(OP))
@@ -2450,6 +2453,10 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
24502453
auto AC = static_cast<SPIRVAccessChainBase *>(BV);
24512454
auto Base = transValue(AC->getBase(), F, BB);
24522455
SPIRVType *BaseSPVTy = AC->getBase()->getType();
2456+
if (BaseSPVTy->isTypePointer() &&
2457+
BaseSPVTy->getPointerElementType()->isTypeCooperativeMatrixKHR()) {
2458+
return mapValue(BV, transSPIRVBuiltinFromInst(AC, BB));
2459+
}
24532460
Type *BaseTy =
24542461
BaseSPVTy->isTypeVector()
24552462
? transType(
@@ -2784,6 +2791,9 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
27842791
Builder.SetInsertPoint(BB);
27852792
}
27862793
SPIRVUnary *BC = static_cast<SPIRVUnary *>(BV);
2794+
if (BV->getType()->isTypeCooperativeMatrixKHR()) {
2795+
return mapValue(BV, transSPIRVBuiltinFromInst(BC, BB));
2796+
}
27872797
auto Neg =
27882798
Builder.CreateNeg(transValue(BC->getOperand(0), F, BB), BV->getName());
27892799
if (auto *NegInst = dyn_cast<Instruction>(Neg)) {
@@ -2836,6 +2846,9 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
28362846

28372847
case OpFNegate: {
28382848
SPIRVUnary *BC = static_cast<SPIRVUnary *>(BV);
2849+
if (BV->getType()->isTypeCooperativeMatrixKHR()) {
2850+
return mapValue(BV, transSPIRVBuiltinFromInst(BC, BB));
2851+
}
28392852
auto Neg = UnaryOperator::CreateFNeg(transValue(BC->getOperand(0), F, BB),
28402853
BV->getName(), BB);
28412854
applyFPFastMathModeDecorations(BV, Neg);

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,10 @@ class SPIRVBinary : public SPIRVInstTemplateBase {
649649
assert(getValueType(Op1)->getVectorComponentCount() ==
650650
getValueType(Op2)->getVectorComponentCount() &&
651651
"Inconsistent Vector component width");
652+
} else if (getValueType(Op1)->isTypeCooperativeMatrixKHR()) {
653+
Op1Ty = getValueType(Op1)->getVectorComponentType();
654+
Op2Ty = getValueType(Op2)->getVectorComponentType();
655+
assert(Op1Ty == Op2Ty && "Inconsistent Cooperative matrix types");
652656
} else {
653657
Op1Ty = getValueType(Op1);
654658
Op2Ty = getValueType(Op2);
@@ -1573,10 +1577,13 @@ class SPIRVUnary : public SPIRVInstTemplateBase {
15731577
return;
15741578
if (isGenericNegateOpCode(OpCode)) {
15751579
SPIRVType *ResTy =
1576-
Type->isTypeVector() ? Type->getVectorComponentType() : Type;
1577-
SPIRVType *OpTy = Type->isTypeVector()
1578-
? getValueType(Op)->getVectorComponentType()
1579-
: getValueType(Op);
1580+
Type->isTypeVector() || Type->isTypeCooperativeMatrixKHR()
1581+
? Type->getVectorComponentType()
1582+
: Type;
1583+
SPIRVType *OpTy =
1584+
Type->isTypeVector() || Type->isTypeCooperativeMatrixKHR()
1585+
? getValueType(Op)->getVectorComponentType()
1586+
: getValueType(Op);
15801587

15811588
(void)ResTy;
15821589
(void)OpTy;
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_cooperative_matrix -o %t.spv
3+
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
6+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
7+
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
8+
9+
; CHECK-SPIRV: TypeInt [[#TypeInt:]] 32 0
10+
; CHECK-SPIRV-DAG: Constant [[#TypeInt]] [[#Const0:]] 0
11+
; CHECK-SPIRV-DAG: Constant [[#TypeInt]] [[#Const1:]] 1 {{$}}
12+
; CHECK-SPIRV-DAG: Constant [[#TypeInt]] [[#Const3:]] 3
13+
; CHECK-SPIRV-DAG: Constant [[#TypeInt]] [[#Const12:]] 12
14+
; CHECK-SPIRV-DAG: Constant [[#TypeInt]] [[#Const42:]] 42
15+
16+
; CHECK-SPIRV: TypeCooperativeMatrixKHR [[#TypeMatrix:]] [[#TypeInt]] [[#Const3]] [[#Const12]] [[#Const12]] [[#Const0]]
17+
; CHECK-SPIRV: TypePointer [[#TypeMatrixPtr:]] 7 [[#TypeMatrix]]
18+
; CHECK-SPIRV: TypePointer [[#TypeIntPtr:]] 7 [[#TypeInt]]
19+
20+
; CHECK-SPIRV: Variable [[#TypeMatrixPtr]] [[#VarMatrixPtr:]] 7
21+
; CHECK-SPIRV: CompositeConstruct [[#TypeMatrix]] [[#Composite:]] [[#Const0]]
22+
; CHECK-SPIRV: Store [[#VarMatrixPtr]] [[#Composite]]
23+
; CHECK-SPIRV: AccessChain [[#TypeIntPtr]] [[#Res:]] [[#VarMatrixPtr]] [[#Const1]]
24+
; CHECK-SPIRV: Store [[#Res]] [[#Const42]]
25+
26+
; CHECK-LLVM: %0 = alloca target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0)
27+
; CHECK-LLVM: %Obj = call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) @_Z26__spirv_CompositeConstructi(i32 0)
28+
; CHECK-LLVM: store target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) %Obj, ptr %0
29+
; CHECK-LLVM: %call = call spir_func ptr @_Z19__spirv_AccessChainPPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_0i(ptr %0, i32 1)
30+
; CHECK-LLVM: store i32 42, ptr %call
31+
32+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
33+
target triple = "spir64-unknown-unknown"
34+
35+
; Function Attrs: mustprogress uwtable
36+
define dso_local void @_Z3fooi(i32 noundef %idx) local_unnamed_addr #0 {
37+
entry:
38+
%0 = alloca target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0), align 8
39+
%Obj = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) @_Z26__spirv_CompositeConstruct(i32 noundef 0) #4
40+
store target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) %Obj, ptr %0, align 8
41+
%call = call noundef ptr @_Z19__spirv_AccessChainP6Matrixii(ptr %0, i32 noundef 1)
42+
call void @_Z13__spirv_StorePii(ptr noundef %call, i32 noundef 42)
43+
ret void
44+
}
45+
46+
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 0) @_Z26__spirv_CompositeConstruct(i32 noundef) local_unnamed_addr #2
47+
48+
declare noundef ptr @_Z19__spirv_AccessChainP6Matrixii(ptr noundef, i32 noundef) local_unnamed_addr #2
49+
50+
declare void @_Z13__spirv_StorePii(ptr noundef, i32 noundef) local_unnamed_addr #2
51+
52+
attributes #0 = { mustprogress uwtable "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
53+
attributes #1 = { mustprogress nocallback nofree nosync nounwind willreturn memory(argmem: readwrite) }
54+
attributes #2 = { "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
55+
attributes #3 = { nounwind }
56+
57+
!llvm.module.flags = !{!0, !1, !2, !3, !4}
58+
!llvm.ident = !{!5}
59+
60+
!0 = !{i32 7, !"Dwarf Version", i32 4}
61+
!1 = !{i32 1, !"wchar_size", i32 4}
62+
!2 = !{i32 8, !"PIC Level", i32 2}
63+
!3 = !{i32 7, !"PIE Level", i32 2}
64+
!4 = !{i32 7, !"uwtable", i32 2}
65+
!5 = !{!"clang version 16.0.0 (https://github.com/llvm/llvm-project.git 08d094a0e457360ad8b94b017d2dc277e697ca76)"}

0 commit comments

Comments
 (0)