Skip to content

Commit 93994ed

Browse files
committed
[NVPTX] Resolved merge conflicts + updated check for PTX version
1 parent 7fb0570 commit 93994ed

File tree

4 files changed

+29
-12
lines changed

4 files changed

+29
-12
lines changed

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,8 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
178178
string gft = Geom#":"#Frag#":"#ptx_elt_type;
179179
string gf = Geom#":"#Frag;
180180
string ft = frag#":"#ptx_elt_type;
181-
list<LLVMType> regs = !if(!eq(IsSparse, true),
181+
bit isSparse = IsSparse;
182+
list<LLVMType> regs = !if(!eq(isSparse, true),
182183
!cond(
183184
// mma sparse ops use other fragments for some arguments
184185
!eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 2),

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4619,7 +4619,8 @@ def INT_PTX_SREG_WARPSIZE :
46194619
// the fields commonly used to implement a specific PTX instruction -- register
46204620
// types and names, constraints, parts of assembly, etc.
46214621
class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "">
4622-
: WMMA_REGS<r.geom, r.frag, r.ptx_elt_type, !eq(op, "mma.sp")> {
4622+
: WMMA_REGS<r.geom, r.frag, r.ptx_elt_type,
4623+
!or(!eq(op, "mma.sp"), !eq(op, "mma.sp.block_scale"))> {
46234624
// NVPTX register types used to carry fragment data.
46244625
NVPTXRegClass regclass = !cond(
46254626
!eq(ptx_elt_type, "e4m3") : B32,
@@ -4659,6 +4660,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
46594660
// longer the case, we can concat all per-fragment predicates to enforce that
46604661
// all fragments of the instruction are viable.
46614662
list<Predicate> Predicates = !cond(
4663+
!or(!eq(op, "mma.block_scale"),
4664+
!eq(op, "mma.sp.block_scale")) : [hasSM120a, hasPTX<88>],
4665+
46624666
!or(!eq(ptx_elt_type, "e3m2"),
46634667
!eq(ptx_elt_type, "e2m3"),
46644668
!eq(ptx_elt_type, "e2m1"),
@@ -4671,9 +4675,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
46714675
!or(!eq(ptx_elt_type, "e4m3"),
46724676
!eq(ptx_elt_type, "e5m2")) : [hasSM<89>, hasPTX<84>],
46734677

4674-
!and(!eq(op, "mma.sp"),
4678+
!and(isSparse,
46754679
!ne(metadata, "sp")) : [hasSM<80>, hasPTX<85>],
4676-
!eq(op, "mma.sp") : [hasSM<80>, hasPTX<71>],
4680+
isSparse : [hasSM<80>, hasPTX<71>],
46774681

46784682
// fp16 -> fp16/fp32 @ m16n16k16
46794683
!and(!eq(geom, "m16n16k16"),
@@ -5027,7 +5031,7 @@ class MMA_BLOCK_SCALE<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
50275031
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
50285032
string Kind, string SType, string ScaleVecSize>
50295033
: WMMA_INSTR<MMA_BLOCK_SCALE_NAME<Kind, SType, ScaleVecSize,
5030-
FragA, FragB, FragC, FragD>.record,
5034+
FragA, FragB, FragC, FragD>.record_name,
50315035
[FragA.Ins, FragB.Ins, FragC.Ins,
50325036
(ins B32:$scale_a, B16:$byte_id_a,
50335037
B16:$thread_id_a, B32:$scale_b,
@@ -5144,7 +5148,7 @@ class MMA_SP_BLOCK_SCALE<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
51445148
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
51455149
string Kind, string SType, string ScaleVecSize>
51465150
: WMMA_INSTR<MMA_SP_BLOCK_SCALE_NAME<Kind, SType, ScaleVecSize,
5147-
FragA, FragB, FragC, FragD>.record,
5151+
FragA, FragB, FragC, FragD>.record_name,
51485152
[FragA.Ins, FragB.Ins, FragC.Ins,
51495153
(ins B32:$metadata, i32imm:$selector,
51505154
B32:$scale_a, B16:$byte_id_a, B16:$thread_id_a,
@@ -5192,10 +5196,10 @@ defset list<WMMA_INSTR> MMA_SP_BLOCK_SCALEs = {
51925196
foreach stype = ["ue8m0", "ue4m3"] in {
51935197
foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in {
51945198
if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
5195-
def : MMA_SP_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.sp", "sp::ordered_metadata", kind>,
5196-
WMMA_REGINFO<op[1], "mma.sp", "sp::ordered_metadata", kind>,
5197-
WMMA_REGINFO<op[2], "mma.sp", "sp::ordered_metadata", kind>,
5198-
WMMA_REGINFO<op[3], "mma.sp", "sp::ordered_metadata", kind>,
5199+
def : MMA_SP_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
5200+
WMMA_REGINFO<op[1], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
5201+
WMMA_REGINFO<op[2], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
5202+
WMMA_REGINFO<op[3], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
51995203
kind, stype, scale_vec_size>;
52005204
}
52015205
} // op
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Check all variants of instructions supported by PTX88 on SM120a
2+
# RUN: %python %s --ptx=88 --gpu-arch=120 --aa > %t-ptx88-sm_120a.ll
3+
# RUN: llc < %t-ptx88-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx88 \
4+
# RUN: | FileCheck %t-ptx88-sm_120a.ll
5+
# RUN: %if ptxas-sm_120a && ptxas-isa-8.8 %{ \
6+
# RUN: llc < %t-ptx88-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx88 \
7+
# RUN: | %ptxas-verify -arch=sm_120a \
8+
# RUN: %}
9+
10+
import wmma
11+
12+
wmma.main()

llvm/test/CodeGen/NVPTX/wmma.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1246,7 +1246,7 @@ def common_mma_block_scale_test_gen(
12461246

12471247

12481248
def gen_mma_block_scale_tests():
1249-
if not (ptx_version >= 87 and gpu_arch >= 120 and aa):
1249+
if not (ptx_version >= 88 and gpu_arch >= 120 and aa):
12501250
return []
12511251

12521252
mma_block_scale_intrinsic_template = "llvm.nvvm.mma.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
@@ -1646,7 +1646,7 @@ def common_mma_sp_block_scale_test_gen(
16461646

16471647

16481648
def gen_mma_sp_block_scale_tests():
1649-
if not (ptx_version >= 87 and gpu_arch >= 120 and aa):
1649+
if not (ptx_version >= 88 and gpu_arch >= 120 and aa):
16501650
return []
16511651

16521652
mma_sp_block_scale_intrinsic_template = "llvm.nvvm.mma.sp.ordered.metadata.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"

0 commit comments

Comments
 (0)