@@ -1336,37 +1336,36 @@ def test_static_assert():
13361336
13371337
13381338@pytest .mark .parametrize ("reg_layout, shared_layout, shape, bitwidth, ref_conflicts" , [
1339- (ttgl .BlockedLayout ([1 ], [32 ], [4 ], [0 ]), ttgl .SwizzledSharedLayout (1 , 1 , 1 , order = [0 ]), [32 ], 32 , 1 ),
1340- # FIXME: This one should be zero conflicts due to broadcasting.
1341- (ttgl .BlockedLayout ([1 ], [32 ], [4 ], [0 ]), ttgl .SwizzledSharedLayout (1 , 1 , 1 , order = [0 ]), [32 ], 16 , 2 ),
1339+ (ttgl .BlockedLayout ([1 ], [32 ], [4 ], [0 ]), ttgl .SwizzledSharedLayout (1 , 1 , 1 , order = [0 ]), [32 ], 32 , 0 ),
1340+ (ttgl .BlockedLayout ([1 ], [32 ], [4 ], [0 ]), ttgl .SwizzledSharedLayout (1 , 1 , 1 , order = [0 ]), [32 ], 16 , 0 ),
13421341 # MMAv3 accumulator tile lowered with the 128B swizzle (WGMMA default path).
13431342 (ttgl .NVMMADistributedLayout (version = [3 , 0 ], warps_per_cta = [4 , 1 ], instr_shape = [16 , 32 , 16 ]),
1344- ttgl .NVMMASharedLayout (swizzle_byte_width = 128 , element_bitwidth = 16 , rank = 2 ), [128 , 128 ], 16 , 1 ),
1343+ ttgl .NVMMASharedLayout (swizzle_byte_width = 128 , element_bitwidth = 16 , rank = 2 ), [128 , 128 ], 16 , 0 ),
13451344 # Small-M tiles disable swizzling entirely.
13461345 # MMAv2 rhs operand emitted with the 64B swizzle.
13471346 (ttgl .DotOperandLayout (
13481347 operand_index = 1 , parent = ttgl .NVMMADistributedLayout (version = [2 , 0 ], warps_per_cta = [1 , 4 ], instr_shape = [16 , 8 ]),
1349- k_width = 2 ), ttgl .NVMMASharedLayout (swizzle_byte_width = 64 , element_bitwidth = 16 , rank = 2 ), [64 , 32 ], 16 , 2 ),
1348+ k_width = 2 ), ttgl .NVMMASharedLayout (swizzle_byte_width = 64 , element_bitwidth = 16 , rank = 2 ), [64 , 32 ], 16 , 0 ),
13501349 # MMAv2 lhs operand uses the transposed 64B swizzle flavour.
13511350 (ttgl .DotOperandLayout (
13521351 operand_index = 0 , parent = ttgl .NVMMADistributedLayout (version = [2 , 0 ], warps_per_cta = [1 , 4 ], instr_shape = [16 , 8 ]),
13531352 k_width = 2 ), ttgl .NVMMASharedLayout (swizzle_byte_width = 64 , element_bitwidth = 16 , rank = 2 ,
1354- transposed = True ), [32 , 64 ], 16 , 2 ),
1353+ transposed = True ), [32 , 64 ], 16 , 0 ),
13551354 # int8 tensor-core tiles follow the 32B swizzle path.
13561355 (ttgl .DotOperandLayout (
13571356 operand_index = 1 , parent = ttgl .NVMMADistributedLayout (version = [2 , 0 ], warps_per_cta = [1 , 4 ], instr_shape = [16 , 8 ]),
1358- k_width = 1 ), ttgl .NVMMASharedLayout (swizzle_byte_width = 32 , element_bitwidth = 8 , rank = 2 ), [8 , 32 ], 8 , 4 ),
1357+ k_width = 1 ), ttgl .NVMMASharedLayout (swizzle_byte_width = 32 , element_bitwidth = 8 , rank = 2 ), [8 , 32 ], 8 , 0 ),
13591358 # Small-M tiles disable swizzling entirely.
13601359 (ttgl .NVMMADistributedLayout (version = [2 , 0 ], warps_per_cta = [4 , 1 ], instr_shape = [16 , 8 ]),
1361- ttgl .NVMMASharedLayout (swizzle_byte_width = 64 , element_bitwidth = 16 , rank = 2 , transposed = True ), [64 , 64 ], 16 , 2 ),
1360+ ttgl .NVMMASharedLayout (swizzle_byte_width = 64 , element_bitwidth = 16 , rank = 2 , transposed = True ), [64 , 64 ], 16 , 0 ),
13621361 (ttgl .NVMMADistributedLayout (version = [3 , 0 ], warps_per_cta = [2 , 2 ], instr_shape = [16 , 32 , 16 ]),
1363- ttgl .NVMMASharedLayout (swizzle_byte_width = 64 , element_bitwidth = 16 , rank = 2 ), [64 , 32 ], 16 , 1 ),
1362+ ttgl .NVMMASharedLayout (swizzle_byte_width = 64 , element_bitwidth = 16 , rank = 2 ), [64 , 32 ], 16 , 0 ),
13641363 (ttgl .NVMMADistributedLayout (version = [2 , 0 ], warps_per_cta = [4 , 1 ], instr_shape = [16 , 8 ]),
1365- ttgl .NVMMASharedLayout (swizzle_byte_width = 32 , element_bitwidth = 8 , rank = 2 ), [32 , 32 ], 8 , 2 ),
1364+ ttgl .NVMMASharedLayout (swizzle_byte_width = 32 , element_bitwidth = 8 , rank = 2 ), [32 , 32 ], 8 , 0 ),
13661365 (ttgl .NVMMADistributedLayout (version = [2 , 0 ], warps_per_cta = [2 , 4 ], instr_shape = [16 , 8 ]),
1367- ttgl .NVMMASharedLayout (swizzle_byte_width = 0 , element_bitwidth = 16 , rank = 2 ), [4 , 64 ], 16 , 4 ),
1366+ ttgl .NVMMASharedLayout (swizzle_byte_width = 0 , element_bitwidth = 16 , rank = 2 ), [4 , 64 ], 16 , 3 ),
13681367 (ttgl .NVMMADistributedLayout (version = [3 , 0 ], warps_per_cta = [4 , 1 ], instr_shape = [16 , 32 , 16 ]),
1369- ttgl .NVMMASharedLayout (swizzle_byte_width = 128 , element_bitwidth = 32 , rank = 2 ), [128 , 64 ], 32 , 2 ),
1368+ ttgl .NVMMASharedLayout (swizzle_byte_width = 128 , element_bitwidth = 32 , rank = 2 ), [128 , 64 ], 32 , 1 ),
13701369])
13711370def test_bank_conflicts (reg_layout , shared_layout , shape , bitwidth , ref_conflicts ):
13721371 dtype = {8 : ttgl .int8 , 16 : ttgl .float16 , 32 : ttgl .float32 }[bitwidth ]
0 commit comments