Skip to content

Commit e129576

Browse files
committed
Add example usage docstring for blockscaled contiguous gather grouped gemm swiglu fusion kernel
Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
1 parent 8c3aa4e commit e129576

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

tests/scripts/cute_dsl_kernels/run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,38 @@
2626
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2727
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2828

29+
"""Example usage of the kernel.
30+
31+
Functional testing:
32+
python run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py \
33+
--ab_dtype Float4E2M1FN --c_dtype Float4E2M1FN \
34+
--sf_dtype Float8E4M3FN --sf_vec_size 16 \
35+
--mma_tiler_mn 128,128 --cluster_shape_mn 1,1 \
36+
--nkl 4096,7168,8 --fixed_m 128
37+
Or use a benchmark file:
38+
python run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py \
39+
--ab_dtype Float4E2M1FN --c_dtype Float4E2M1FN \
40+
--sf_dtype Float8E4M3FN --sf_vec_size 16 \
41+
--mma_tiler_mn 128,128 --cluster_shape_mn 1,1 \
42+
--benchmark benchmark.txt
43+
Perf testing:
44+
python run_blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py \
45+
--ab_dtype Float4E2M1FN --c_dtype Float4E2M1FN \
46+
--sf_dtype Float8E4M3FN --sf_vec_size 16 \
47+
--mma_tiler_mn 128,128 --cluster_shape_mn 1,1 \
48+
--benchmark benchmark.txt \
49+
--skip_ref_check --use_cold_l2 --use_cupti --warmup_iterations 10 --iterations 50
50+
A sample benchmark.txt file is shown below (one line per group: <group index> <M>x<N>x<K>):
51+
0 89x4096x7168
52+
1 200x4096x7168
53+
2 145x4096x7168
54+
3 178x4096x7168
55+
4 241x4096x7168
56+
5 78x4096x7168
57+
6 198x4096x7168
58+
7 60x4096x7168
59+
"""
60+
2961
import argparse
3062
import sys
3163
from pathlib import Path
@@ -577,13 +609,16 @@ def run(
577609
raise RuntimeError("GPU is required to run this example!")
578610

579611
# Skip unsupported testcase
612+
# Note: For grouped GEMM, we use mma_tiler_mn[0] as the m parameter for the can_implement check
613+
# since individual group M values vary
580614
if not BlockScaledContiguousGatherGroupedGemmKernel.can_implement(
581615
ab_dtype,
582616
sf_dtype,
583617
sf_vec_size,
584618
c_dtype,
585619
mma_tiler_mn,
586620
cluster_shape_mn,
621+
mma_tiler_mn[0], # m (use mma_tiler_m as placeholder for grouped GEMM)
587622
n,
588623
k,
589624
num_groups,

0 commit comments

Comments
 (0)