|
26 | 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
27 | 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
28 | 28 |
|
| 29 | +"""Example usage of the kernel. |
| 30 | +
|
| 31 | +Functional testing: |
| 32 | +python run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py \ |
| 33 | + --ab_dtype Float4E2M1FN --c_dtype Float4E2M1FN \ |
| 34 | + --sf_dtype Float8E4M3FN --sf_vec_size 16 \ |
| 35 | + --mma_tiler_mn 128,128 --cluster_shape_mn 1,1 \ |
| 36 | + --nkl 4096,7168,8 --fixed_m 128 |
| 37 | +or use a benchmark file: |
| 38 | +python run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py \ |
| 39 | + --ab_dtype Float4E2M1FN --c_dtype Float4E2M1FN \ |
| 40 | + --sf_dtype Float8E4M3FN --sf_vec_size 16 \ |
| 41 | + --mma_tiler_mn 128,128 --cluster_shape_mn 1,1 \ |
| 42 | + --benchmark benchmark.txt |
| 43 | +Perf testing: |
| 44 | +python run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py \ |
| 45 | + --ab_dtype Float4E2M1FN --c_dtype Float4E2M1FN \ |
| 46 | + --sf_dtype Float8E4M3FN --sf_vec_size 16 \ |
| 47 | + --mma_tiler_mn 128,128 --cluster_shape_mn 1,1 \ |
| 48 | + --benchmark benchmark.txt \ |
| 49 | + --skip_ref_check --use_cold_l2 --use_cupti --warmup_iterations 10 --iterations 50 |
| 50 | +A sample benchmark.txt file is shown below: |
| 51 | +0 89x4096x7168 |
| 52 | +1 200x4096x7168 |
| 53 | +2 145x4096x7168 |
| 54 | +3 178x4096x7168 |
| 55 | +4 241x4096x7168 |
| 56 | +5 78x4096x7168 |
| 57 | +6 198x4096x7168 |
| 58 | +7 60x4096x7168 |
| 59 | +""" |
| 60 | + |
29 | 61 | import argparse |
30 | 62 | import sys |
31 | 63 | from pathlib import Path |
@@ -577,13 +609,16 @@ def run( |
577 | 609 | raise RuntimeError("GPU is required to run this example!") |
578 | 610 |
|
579 | 611 | # Skip unsupported testcase |
| 612 | + # Note: For grouped GEMM, we use mma_tiler_mn[0] as the m parameter for can_implement check |
| 613 | + # since individual group M values vary |
580 | 614 | if not BlockScaledContiguousGatherGroupedGemmKernel.can_implement( |
581 | 615 | ab_dtype, |
582 | 616 | sf_dtype, |
583 | 617 | sf_vec_size, |
584 | 618 | c_dtype, |
585 | 619 | mma_tiler_mn, |
586 | 620 | cluster_shape_mn, |
| 621 | + mma_tiler_mn[0], # m (use mma_tiler_m as placeholder for grouped GEMM) |
587 | 622 | n, |
588 | 623 | k, |
589 | 624 | num_groups, |
|
0 commit comments