@@ -4619,7 +4619,8 @@ def INT_PTX_SREG_WARPSIZE :
46194619// the fields commonly used to implement specific PTX instruction -- register
46204620// types and names, constraints, parts of assembly, etc.
46214621class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "">
4622- : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type, !eq(op, "mma.sp")> {
4622+ : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type,
4623+ !or(!eq(op, "mma.sp"), !eq(op, "mma.sp.block_scale"))> {
46234624 // NVPTX register types used to carry fragment data.
46244625 NVPTXRegClass regclass = !cond(
46254626 !eq(ptx_elt_type, "e4m3") : B32,
@@ -4659,6 +4660,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
46594660 // longer the case, we can concat all per-fragment predicates to enforce that
46604661 // all fragments of the instruction are viable.
46614662 list<Predicate> Predicates = !cond(
4663+ !or(!eq(op, "mma.block_scale"),
4664+ !eq(op, "mma.sp.block_scale")) : [hasSM120a, hasPTX<88>],
4665+
46624666 !or(!eq(ptx_elt_type, "e3m2"),
46634667 !eq(ptx_elt_type, "e2m3"),
46644668 !eq(ptx_elt_type, "e2m1"),
@@ -4671,9 +4675,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
46714675 !or(!eq(ptx_elt_type, "e4m3"),
46724676 !eq(ptx_elt_type, "e5m2")) : [hasSM<89>, hasPTX<84>],
46734677
4674- !and(!eq(op, "mma.sp") ,
4678+ !and(isSparse ,
46754679 !ne(metadata, "sp")) : [hasSM<80>, hasPTX<85>],
4676- !eq(op, "mma.sp") : [hasSM<80>, hasPTX<71>],
4680+ isSparse : [hasSM<80>, hasPTX<71>],
46774681
46784682 // fp16 -> fp16/fp32 @ m16n16k16
46794683 !and(!eq(geom, "m16n16k16"),
@@ -5027,7 +5031,7 @@ class MMA_BLOCK_SCALE<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
50275031 WMMA_REGINFO FragC, WMMA_REGINFO FragD,
50285032 string Kind, string SType, string ScaleVecSize>
50295033 : WMMA_INSTR<MMA_BLOCK_SCALE_NAME<Kind, SType, ScaleVecSize,
5030- FragA, FragB, FragC, FragD>.record ,
5034+ FragA, FragB, FragC, FragD>.record_name ,
50315035 [FragA.Ins, FragB.Ins, FragC.Ins,
50325036 (ins B32:$scale_a, B16:$byte_id_a,
50335037 B16:$thread_id_a, B32:$scale_b,
@@ -5144,7 +5148,7 @@ class MMA_SP_BLOCK_SCALE<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
51445148 WMMA_REGINFO FragC, WMMA_REGINFO FragD,
51455149 string Kind, string SType, string ScaleVecSize>
51465150 : WMMA_INSTR<MMA_SP_BLOCK_SCALE_NAME<Kind, SType, ScaleVecSize,
5147- FragA, FragB, FragC, FragD>.record ,
5151+ FragA, FragB, FragC, FragD>.record_name ,
51485152 [FragA.Ins, FragB.Ins, FragC.Ins,
51495153 (ins B32:$metadata, i32imm:$selector,
51505154 B32:$scale_a, B16:$byte_id_a, B16:$thread_id_a,
@@ -5192,10 +5196,10 @@ defset list<WMMA_INSTR> MMA_SP_BLOCK_SCALEs = {
51925196 foreach stype = ["ue8m0", "ue4m3"] in {
51935197 foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in {
51945198 if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
5195- def : MMA_SP_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.sp", "sp::ordered_metadata", kind>,
5196- WMMA_REGINFO<op[1], "mma.sp", "sp::ordered_metadata", kind>,
5197- WMMA_REGINFO<op[2], "mma.sp", "sp::ordered_metadata", kind>,
5198- WMMA_REGINFO<op[3], "mma.sp", "sp::ordered_metadata", kind>,
5199+ def : MMA_SP_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.sp.block_scale ", "sp::ordered_metadata", kind>,
5200+ WMMA_REGINFO<op[1], "mma.sp.block_scale ", "sp::ordered_metadata", kind>,
5201+ WMMA_REGINFO<op[2], "mma.sp.block_scale ", "sp::ordered_metadata", kind>,
5202+ WMMA_REGINFO<op[3], "mma.sp.block_scale ", "sp::ordered_metadata", kind>,
51995203 kind, stype, scale_vec_size>;
52005204 }
52015205 } // op
0 commit comments