@@ -119,8 +119,8 @@ def shared_memory_kernel(XBLOCK: ttgl.constexpr, YBLOCK: ttgl.constexpr, layout_
                          layout_b: ttgl.constexpr, smem_layout: ttgl.constexpr):
     unused = ttgl.allocate_shared_memory(ttgl.int32, [XBLOCK, YBLOCK], smem_layout)
     a = ttgl.full([XBLOCK, YBLOCK], 0, ttgl.int32, layout_a)
-    tl.static_assert(a.numel == unused.numel)
-    tl.static_assert(unused.numel == XBLOCK * YBLOCK)
+    ttgl.static_assert(a.numel == unused.numel)
+    ttgl.static_assert(unused.numel == XBLOCK * YBLOCK)
     mem = ttgl.allocate_shared_memory(ttgl.int32, a.shape, smem_layout, a)
     b = mem.load(layout_b)  # noqa: F841
     mem.store(a)
@@ -641,7 +641,7 @@ def async_tma_kernel(input_desc, XBLOCK: ttgl.constexpr):
     mbarrier.init(bar, count=1)

     tma.async_copy_global_to_shared(input_desc, [0, 0], bar, smem)
-    tl.static_assert(input_desc.block_type.nbytes == XBLOCK * XBLOCK * 2)
+    ttgl.static_assert(input_desc.block_type.nbytes == XBLOCK * XBLOCK * 2)
     mbarrier.expect(bar, input_desc.block_type.nbytes)
     mbarrier.wait(bar, 0)

@@ -941,7 +941,7 @@ def reduce_kernel(out):
     ttgl.static_assert(pairs[0].type.layout == ttgl.SliceLayout(0, layout))
     ttgl.static_assert(pairs[1].type.layout == ttgl.SliceLayout(0, layout))
     result = scalar + s1 + pairs[0] + pairs[1]
-    tl.store(out + ttgl.arange(0, 16, s0.type.layout), result)
+    ttgl.store(out + ttgl.arange(0, 16, s0.type.layout), result)


 @pytest.mark.parametrize("target", ALL_TARGETS)
@@ -1057,8 +1057,8 @@ def test_elementwise_core():

 @gluon.jit
 def linear_layout_kernel():
-    ll: tl.constexpr = ttgl.DistributedLinearLayout(reg_bases=[[1]], lane_bases=[[2], [4], [8], [16], [32]],
-                                                    warp_bases=[[64], [128]], block_bases=[], shape=[256])
+    ll: ttgl.constexpr = ttgl.DistributedLinearLayout(reg_bases=[[1]], lane_bases=[[2], [4], [8], [16], [32]],
+                                                      warp_bases=[[64], [128]], block_bases=[], shape=[256])
     ttgl.arange(0, 256, layout=ll)


@@ -1077,6 +1077,20 @@ def test_linear_layout(target):
     """)


+@filecheck_test
+@gluon.jit
+def test_dot_operand_layout():
+    # CHECK: [[NVMMA:#.*]] = #ttg.nvidia_mma
+    # CHECK: test_dot_operand_layout
+    mma_layout: ttgl.constexpr = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1],
+                                                             instr_shape=[16, 32, 16])
+    layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=0, parent=mma_layout, k_width=2)
+    # CHECK: arith.constant {{.*}} tensor<256x128xf16, #ttg.dot_op<{opIdx = 0, parent = [[NVMMA]], kWidth = 2}>>
+    x = ttgl.full([256, 128], 0.0, ttgl.float16, layout)
+    y = x.sum(axis=1)
+    ttgl.static_assert(y.type.layout.parent == layout)
+
+
 @filecheck_test
 @gluon.jit
 def test_tensor_permute():
@@ -1201,7 +1215,7 @@ def async_copy_kernel(inp, xnumel, XBLOCK: ttgl.constexpr):
     smem = ttgl.allocate_shared_memory(inp.dtype.element_ty, [XBLOCK], ttgl.SwizzledSharedLayout(1, 1, 1, order=[0]))
     block_layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0])
     xindex = ttgl.arange(0, XBLOCK, block_layout)
-    mask = tl.max_constancy(xindex < xnumel, 2)
+    mask = ttgl.max_constancy(xindex < xnumel, 2)

     async_copy.async_copy_global_to_shared(smem, inp + xindex)
     async_copy.async_copy_global_to_shared(smem, inp + xindex, mask, cache_modifier=".ca", eviction_policy="evict_last",