1+ import math
2+ import torch
13from dataclasses import dataclass
24import triton
35import triton .language as tl
@@ -12,24 +14,31 @@ class CDNA4MXScaleLayout(Layout):
1214
def __init__(self, shape) -> None:
    """Record the logical shape and the padded sizes used by the swizzle.

    Args:
        shape: sequence unpacked as ``(*leading, K_SCALE, N)``; it must
            contain at least two entries, otherwise the unpacking raises
            ``ValueError``.
    """
    super().__init__(shape)
    *leading, k_scale, n = shape
    self.leading_shape = leading
    self.K_SCALE = k_scale
    self.N = n
    # Number of matrices addressed by the leading dims (1 when there are none).
    self.B = math.prod(self.leading_shape)
    # Granularity the last two dims are rounded up to before swizzling.
    self.ALIGN_K_SCALE = 8
    self.ALIGN_N = 32
    # Round up with integer arithmetic only: ceil(x / a) via float division
    # is unnecessary here and loses exactness for very large integers.
    self.K_SCALE_pad = (self.K_SCALE + self.ALIGN_K_SCALE - 1) // self.ALIGN_K_SCALE * self.ALIGN_K_SCALE
    self.N_pad = (self.N + self.ALIGN_N - 1) // self.ALIGN_N * self.ALIGN_N
1527
def swizzle_data(self, data):
    """Pad ``data`` to the aligned sizes and rearrange it into the swizzled layout.

    The last dim is padded by ``N_pad - N`` and the second-to-last by
    ``K_SCALE_pad - K_SCALE`` (``torch.nn.functional.pad`` default mode),
    then the tensor is transposed, split into an 8-D tile view, permuted
    and flattened.

    Args:
        data: tensor to swizzle.
            NOTE(review): presumably shaped ``(*leading_shape, K_SCALE, N)``
            as given to ``__init__`` -- nothing here checks it; confirm.

    Returns:
        A transposed view of a ``(B, N_pad // 32, K_SCALE_pad * 32)``
        tensor, i.e. a tensor of shape ``(B, K_SCALE_pad * 32, N_pad // 32)``.
    """
    pad_n = self.N_pad - self.N
    pad_k = self.K_SCALE_pad - self.K_SCALE
    # F.pad lists pad amounts starting from the last dimension.
    padded = torch.nn.functional.pad(data, (0, pad_n, 0, pad_k))
    # NON_K_PRESHUFFLE_BLOCK_SIZE is defined elsewhere in the module; the
    # (2, 16) split that follows it suggests it equals 32 -- TODO confirm.
    tiles = padded.transpose(-1, -2).view(
        -1,
        self.N_pad // NON_K_PRESHUFFLE_BLOCK_SIZE,
        2,
        16,
        self.K_SCALE_pad // 8,
        2,
        4,
        1,
    )
    # Reorder the tile axes, then materialize so the reshape below is a
    # plain flattening of contiguous memory.
    shuffled = tiles.permute(0, 1, 4, 6, 3, 5, 2, 7).contiguous()
    flat = shuffled.reshape(self.B, self.N_pad // 32, self.K_SCALE_pad * 32)
    return flat.transpose(-1, -2)
3035
def unswizzle_data(self, data):
    """Undo ``swizzle_data`` and strip the alignment padding.

    Args:
        data: tensor in the layout returned by ``swizzle_data``.
            NOTE(review): assumed to have ``B * N_pad * K_SCALE_pad``
            elements arranged as ``swizzle_data`` leaves them -- confirm
            against callers.

    Returns:
        A view shaped ``(*leading_shape, K_SCALE, N)``: the padded tensor
        with its last two dims sliced back to the logical sizes.
    """
    n_blocks = self.N_pad // NON_K_PRESHUFFLE_BLOCK_SIZE
    k_blocks = self.K_SCALE_pad // 8
    # Axis sizes here are the post-permute tile sizes produced by swizzle_data.
    tiles = data.transpose(-1, -2).view(-1, n_blocks, k_blocks, 4, 16, 2, 2, 1)
    # (0, 1, 6, 4, 2, 5, 3, 7) is the inverse of swizzle_data's
    # (0, 1, 4, 6, 3, 5, 2, 7) permutation.
    restored = tiles.permute(0, 1, 6, 4, 2, 5, 3, 7)
    padded = restored.reshape(*self.leading_shape, self.N_pad, self.K_SCALE_pad)
    # Transpose back to (..., K_SCALE_pad, N_pad) and drop the padding.
    return padded.transpose(-1, -2)[..., :self.K_SCALE, :self.N]
3342
3443 def swizzle_block_shape (self , block_shape ):
3544 SCALE_K = block_shape [- 2 ]
0 commit comments