39 changes: 32 additions & 7 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -72,6 +72,7 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
string frag = Frag;
string ptx_elt_type = PtxEltType;
string gft = Geom#":"#Frag#":"#ptx_elt_type;
string gf = Geom#":"#Frag;
string ft = frag#":"#ptx_elt_type;
list<LLVMType> regs = !cond(
// mma fp ops use smaller fragments than wmma fp ops
@@ -214,9 +215,19 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
!eq(gft,"m16n8k256:d:s32") : !listsplat(llvm_i32_ty, 4),

// ldmatrix b16 -> s32 @ m8n8
!eq(gft,"m8n8:x1:b16") : !listsplat(llvm_i32_ty, 1),
!eq(gft,"m8n8:x2:b16") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m8n8:x4:b16") : !listsplat(llvm_i32_ty, 4),
!eq(gf,"m8n8:x1") : !listsplat(llvm_i32_ty, 1),
!eq(gf,"m8n8:x2") : !listsplat(llvm_i32_ty, 2),
!eq(gf,"m8n8:x4") : !listsplat(llvm_i32_ty, 4),

// ldmatrix b8, b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m16n16
!eq(gf,"m16n16:x1") : !listsplat(llvm_i32_ty, 2),
!eq(gf,"m16n16:x2") : !listsplat(llvm_i32_ty, 4),

// ldmatrix b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m8n16
!eq(gf,"m8n16:x1") : !listsplat(llvm_i32_ty, 1),
!eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
!eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),

);
}
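
Editor's note on the register counts above: ldmatrix spreads each fragment across the 32 lanes of a warp as 32-bit registers, so the entries follow from num_matrices x rows x cols x bytes_per_element / (32 lanes x 4 bytes). A minimal Python sketch of that arithmetic (an illustration under that assumption, not part of the patch):

# Expected i32 registers per ldmatrix fragment, assuming the usual
# warp-wide distribution: 32 lanes, 4 bytes per register. The b8x16.*
# formats occupy one byte per element container.
def ldmatrix_regs(geom, num_matrices, bytes_per_elt):
    rows, cols = {"m8n8": (8, 8), "m16n16": (16, 16), "m8n16": (8, 16)}[geom]
    return num_matrices * rows * cols * bytes_per_elt // (32 * 4)

assert ldmatrix_regs("m8n8", 1, 2) == 1    # m8n8:x1:b16
assert ldmatrix_regs("m16n16", 1, 1) == 2  # m16n16:x1:b8
assert ldmatrix_regs("m8n16", 4, 1) == 4   # m8n16:x4:b8x16.*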

@@ -421,7 +432,16 @@ class NVVM_MMA_OPS {

list<WMMA_REGS> ldmatrix_b16_ops = LDMATRIX_OPS<
["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
list<WMMA_REGS> all_ldmatrix_ops = ldmatrix_b16_ops;

list<WMMA_REGS> ldmatrix_geom_m16n16_ops = LDMATRIX_OPS<
["m16n16"], ["x1", "x2"], ["b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;

list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;

list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
ldmatrix_geom_m16n16_ops,
ldmatrix_geom_m8n16_ops);
}

def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -546,13 +566,18 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
// if NVVM_LDMATRIX_SUPPORTED<...>.ret then
// def : FOO<>; // The record will only be defined for supported ops.
//
class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag> {
class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
string g = frag.geom;
string t = frag.ptx_elt_type;

bit ret = !cond(
// m8n8 with b16 is accepted with and without .trans; the m16n16 shapes
// are transposed-only and the m8n16 shapes are non-transposed-only.
!and(!eq(g, "m8n8"), !eq(t, "b16")): true,
!and(!eq(g, "m16n16"), !eq(t, "b8"), !eq(trans, 1)): true,
!and(!eq(g, "m16n16"), !eq(t, "b8x16.b6x16_p32"), !eq(trans, 1)): true,
!and(!eq(g, "m16n16"), !eq(t, "b8x16.b4x16_p64"), !eq(trans, 1)): true,
!and(!eq(g, "m8n16"), !eq(t, "b8"), !eq(trans, 0)): true,
!and(!eq(g, "m8n16"), !eq(t, "b8x16.b6x16_p32"), !eq(trans, 0)): true,
!and(!eq(g, "m8n16"), !eq(t, "b8x16.b4x16_p64"), !eq(trans, 0)): true,
true: false
);
}
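
For readability, here is the predicate above restated as a Python sketch (the TableGen !cond is authoritative): m8n8/b16 is accepted with and without .trans, the m16n16 shapes require .trans, and the m8n16 shapes forbid it.

# Illustrative restatement of NVVM_LDMATRIX_SUPPORTED.
B8_TYPES = {"b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"}

def ldmatrix_supported(geom, ptx_type, trans):
    if geom == "m8n8":
        return ptx_type == "b16"                   # with or without .trans
    if geom == "m16n16":
        return ptx_type in B8_TYPES and trans      # .trans required
    if geom == "m8n16":
        return ptx_type in B8_TYPES and not trans  # .trans not allowed
    return False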
@@ -4983,7 +5008,7 @@ class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>

foreach transposed = [0, 1] in {
foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in {
if NVVM_LDMATRIX_SUPPORTED<frag>.ret then {
if NVVM_LDMATRIX_SUPPORTED<frag, transposed>.ret then {
def LDMATRIX_NAME<frag, transposed>.record
: NVVM_LDMATRIX<frag, transposed>;
}
18 changes: 15 additions & 3 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3668,7 +3668,12 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16: {
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v4i32;
Info.ptrVal = I.getArgOperand(0);
@@ -3708,7 +3713,9 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16: {
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::i32;
Info.ptrVal = I.getArgOperand(0);
@@ -3804,7 +3811,12 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16: {
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v2i32;
Info.ptrVal = I.getArgOperand(0);
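The three switch blocks above differ only in the memVT they record, which tracks the fragment's register count: one i32 register loads as MVT::i32, two as MVT::v2i32, four as MVT::v4i32. Restated in terms of the ldmatrix_regs sketch from earlier (illustrative only):

# memVT per ldmatrix intrinsic follows from the fragment's register count.
MEMVT_FOR_REGS = {1: "MVT::i32", 2: "MVT::v2i32", 4: "MVT::v4i32"}
assert MEMVT_FOR_REGS[ldmatrix_regs("m16n16", 2, 1)] == "MVT::v4i32"
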
27 changes: 25 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -7052,6 +7052,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
!eq(ptx_elt_type, "tf32") : Int32Regs,
!eq(ptx_elt_type, "s32") : Int32Regs,
!eq(ptx_elt_type, "b16") : Int32Regs,
!eq(ptx_elt_type, "b8") : Int32Regs,
!eq(ptx_elt_type, "b8x16.b6x16_p32") : Int32Regs,
!eq(ptx_elt_type, "b8x16.b4x16_p64") : Int32Regs,
!eq(ptx_elt_type, "s8") : Int32Regs,
!eq(ptx_elt_type, "u8") : Int32Regs,
!eq(ptx_elt_type, "s4") : Int32Regs,
@@ -7139,7 +7142,27 @@ class WMMA_REGINFO<WMMA_REGS r, string op>

!and(!eq(op,"ldmatrix"),
!eq(ptx_elt_type,"b16"),
!eq(geom, "m8n8")) : [hasSM<75>, hasPTX<65>]);
!eq(geom, "m8n8")) : [hasSM<75>, hasPTX<65>],

!and(!eq(op,"ldmatrix"),
!eq(ptx_elt_type,"b8"),
!eq(geom, "m16n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>],

!and(!eq(op,"ldmatrix"),
!eq(ptx_elt_type,"b8x16.b6x16_p32"),
!eq(geom, "m16n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>],

!and(!eq(op,"ldmatrix"),
!eq(ptx_elt_type,"b8x16.b4x16_p64"),
!eq(geom, "m16n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>],

!and(!eq(op,"ldmatrix"),
!eq(ptx_elt_type,"b8x16.b6x16_p32"),
!eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>],

!and(!eq(op,"ldmatrix"),
!eq(ptx_elt_type,"b8x16.b4x16_p64"),
!eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);

// template DAGs for instruction inputs/output.
dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -7414,7 +7437,7 @@ defset list<WMMA_INSTR> LDMATRIXs = {
foreach transposed = [false, true] in {
foreach space = [".shared", ""] in {
foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in
if NVVM_LDMATRIX_SUPPORTED<frag>.ret then
if NVVM_LDMATRIX_SUPPORTED<frag, transposed>.ret then
def : LDMATRIX<WMMA_REGINFO<frag, "ldmatrix">, transposed, space>;
} // space
} // transposed
16 changes: 16 additions & 0 deletions llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
@@ -0,0 +1,16 @@
# Check all variants of instructions supported by PTX86 on SM100a
# RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll
# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
# RUN: | FileCheck %t-ptx86-sm_100a.ll
# RUN: %if ptxas-12.7 %{ \
# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
# RUN: | %ptxas-verify -arch=sm_100a \
# RUN: %}

import wmma

wmma.main()
16 changes: 16 additions & 0 deletions llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
@@ -0,0 +1,16 @@
# Check all variants of instructions supported by PTX86 on SM101a
# RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll
# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
# RUN: | FileCheck %t-ptx86-sm_101a.ll
# RUN: %if ptxas-12.7 %{ \
# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
# RUN: | %ptxas-verify -arch=sm_101a \
# RUN: %}

import wmma

wmma.main()
16 changes: 16 additions & 0 deletions llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
@@ -0,0 +1,16 @@
# Check all variants of instructions supported by PTX86 on SM120a
# RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll
# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
# RUN: | FileCheck %t-ptx86-sm_120a.ll
# RUN: %if ptxas-12.7 %{ \
# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
# RUN: | %ptxas-verify -arch=sm_120a \
# RUN: %}

import wmma

wmma.main()
59 changes: 56 additions & 3 deletions llvm/test/CodeGen/NVPTX/wmma.py
@@ -19,6 +19,9 @@ def __init__(self, ptx_type):
"f64": "double",
"s32": "i32",
"b16": "i32",
"b8": "i32",
"b8x16.b6x16_p32": "i32",
"b8x16.b4x16_p64": "i32",
"s8": "i32",
"u8": "i32",
"s4": "i32",
@@ -161,6 +164,18 @@ def __init__(self, geom, frag, ptx_elt_type):
"m8n8:x1:b16": 1,
"m8n8:x2:b16": 2,
"m8n8:x4:b16": 4,
"m16n16:x1:b8": 2,
"m16n16:x2:b8": 4,
"m16n16:x1:b8x16.b6x16_p32": 2,
"m16n16:x2:b8x16.b6x16_p32": 4,
"m16n16:x1:b8x16.b4x16_p64": 2,
"m16n16:x2:b8x16.b4x16_p64": 4,
"m8n16:x1:b8x16.b6x16_p32": 1,
"m8n16:x2:b8x16.b6x16_p32": 2,
"m8n16:x4:b8x16.b6x16_p32": 4,
"m8n16:x1:b8x16.b4x16_p64": 1,
"m8n16:x2:b8x16.b4x16_p64": 2,
"m8n16:x4:b8x16.b4x16_p64": 4,
}.get(
"%s:%s:%s" % (geom, frag, ptx_elt_type),
{
@@ -289,7 +304,15 @@ def get_ldst_ops(kind):


def get_ldmatrix_ops():
return make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
return (
make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
+ make_ldmatrix_ops(
["m16n16"], ["x1", "x2"], ["b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"]
)
+ make_ldmatrix_ops(
["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]
)
)
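
A quick tally of what the generator now enumerates, assuming make_ldmatrix_ops yields the cross product of its arguments: 3 m8n8 variants plus 6 m16n16 plus 6 m8n16, i.e. 15 fragments before the target-support filtering below.

# Sketch: 1*3*1 + 1*2*3 + 1*3*2 == 15 fragment variants in total.
assert len(get_ldmatrix_ops()) == 15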


def is_wmma_geom_supported(geom):
Expand Down Expand Up @@ -330,9 +353,22 @@ def is_mma_geom_supported(geom):
def is_ldmatrix_geom_supported(geom):
if geom in ["m8n8"]:
return ptx_version >= 65 and gpu_arch >= 75
elif geom in ["m16n16"]:
return ptx_version >= 86 and gpu_arch >= 100 and aa
elif geom in ["m8n16"]:
return ptx_version >= 86 and gpu_arch >= 100 and aa
assert False # Unexpected geometry.


def is_ldmatrix_trans_supported(geom, trans):
if geom in ["m8n8"]:
return True
elif geom in ["m16n16"]:
return trans == ".trans"
elif geom in ["m8n16"]:
return trans == ""
assert False # Unexpected geometry.

def is_type_supported(ptx_type):
if ptx_type in ["s8", "u8", "s32"]:
return ptx_version >= 63 and gpu_arch >= 72
@@ -417,10 +453,11 @@ def is_ldst_variant_supported(frag, layout):
return True


def is_ldmatrix_variant_supported(frag):
def is_ldmatrix_variant_supported(frag, trans):
if not (
is_type_supported(frag.mma_type.ptx_type)
and is_ldmatrix_geom_supported(frag.geom)
and is_ldmatrix_trans_supported(frag.geom, trans)
):
return False
return frag.frag in ["x1", "x2", "x4"]
@@ -653,7 +690,7 @@ def gen_ldmatrix_tests():
["", ".shared"],
["", ".trans"],
):
if not is_ldmatrix_variant_supported(frag):
if not is_ldmatrix_variant_supported(frag, trans):
continue

params = {
@@ -944,6 +981,19 @@ def gen_check_unsupported_ops(items):
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16

; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.shared.b8
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.shared.b8
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64

; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
@@ -997,13 +1047,16 @@ def gen_tests():
def main():
global ptx_version
global gpu_arch
global aa
parser = argparse.ArgumentParser()
parser.add_argument("--ptx", type=int, default=60)
parser.add_argument("--gpu-arch", type=int, default=70)
parser.add_argument("--aa", action="store_true")
args = parser.parse_args()

ptx_version = args.ptx
gpu_arch = args.gpu_arch
aa = args.aa

gen_tests()
