Skip to content

Commit dd82e31

Browse files
kvedernigithub-actions[bot]
authored andcommitted
Automerge: [MLIR][NVVM] Support for dense and sparse MMA with block scaling (#170566)
This change adds dense and sparse MMA with block scaling intrinsics to MLIR -> NVVM IR -> NVPTX flow. NVVM and NVPTX implementation is based on PTX ISA 9.0.
2 parents defd7de + 07eb9fa commit dd82e31

File tree

9 files changed

+2493
-267
lines changed

9 files changed

+2493
-267
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 458 additions & 35 deletions
Large diffs are not rendered by default.

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 677 additions & 36 deletions
Large diffs are not rendered by default.

mlir/test/Dialect/LLVMIR/nvvm-mma-blockscale.mlir

Lines changed: 525 additions & 0 deletions
Large diffs are not rendered by default.

mlir/test/Dialect/LLVMIR/nvvm-mma-sparse-blockscale.mlir

Lines changed: 637 additions & 0 deletions
Large diffs are not rendered by default.

mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir

Lines changed: 48 additions & 48 deletions
Large diffs are not rendered by default.

mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-tensor.mlir

Lines changed: 48 additions & 48 deletions
Large diffs are not rendered by default.

mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-invalid.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ llvm.func @nvvm_tcgen05_mma_ashift(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>
4444
llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_default(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scalea: !llvm.ptr<6>, %scaleb: !llvm.ptr<6>) {
4545
// expected-error @below {{mxf4nvf4 requires block scale attribute}}
4646
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scalea, %scaleb
47-
{kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
47+
{kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
4848
llvm.return
4949
}
5050

@@ -54,7 +54,7 @@ llvm.func @nvvm_tcgen05_mma_mxf4nvf4_block_scale_default(%d_tmem : !llvm.ptr<6>,
5454
llvm.func @nvvm_tcgen05_mma_mxf4_block_scale_default(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scalea: !llvm.ptr<6>, %scaleb: !llvm.ptr<6>) {
5555
// expected-error @below {{mxf4 kind does not support block16 attribute}}
5656
nvvm.tcgen05.mma.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %scalea, %scaleb
57-
{kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, ashift, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
57+
{kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, ashift, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
5858
llvm.return
5959
}
6060

@@ -104,7 +104,7 @@ llvm.func @nvvm_tcgen05_mma_sp_ashift(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr
104104
llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_default(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scalea: !llvm.ptr<6>, %scaleb: !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) {
105105
// expected-error @below {{mxf4nvf4 requires block scale attribute}}
106106
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scalea, %scaleb
107-
{kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
107+
{kind = #nvvm.block_scale_kind<mxf4nvf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, aShift} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
108108
llvm.return
109109
}
110110

@@ -114,6 +114,6 @@ llvm.func @nvvm_tcgen05_mma_sp_mxf4nvf4_block_scale_default(%d_tmem : !llvm.ptr<
114114
llvm.func @nvvm_tcgen05_mma_sp_mxf4_block_scale_default(%d_tmem : !llvm.ptr<6>, %a_tmem: !llvm.ptr<6>, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scalea: !llvm.ptr<6>, %scaleb: !llvm.ptr<6>, %spmetadata: !llvm.ptr<6>) {
115115
// expected-error @below {{mxf4 kind does not support block16 attribute}}
116116
nvvm.tcgen05.mma.sp.block_scale %d_tmem, %a_tmem, %b_desc, %idesc, %enable_input_d, %spmetadata, %scalea, %scaleb
117-
{kind = #nvvm.tcgen05_mma_block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, ashift, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
117+
{kind = #nvvm.block_scale_kind<mxf4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>, ashift, blockScale = #nvvm.tcgen05_mma_block_scale<block16>} : (!llvm.ptr<6>, !llvm.ptr<6>, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>, !llvm.ptr<6>)
118118
llvm.return
119119
}

mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-shared.mlir

Lines changed: 48 additions & 48 deletions
Large diffs are not rendered by default.

mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-tensor.mlir

Lines changed: 48 additions & 48 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)