@@ -58,6 +58,7 @@ def get_flops(self) -> int:
 @dataclass
 class TuningSpec:
     wg_tiles: list[int]
+    reduction_tiles: list[int]
     M_warp: int
     N_warp: int
     intrinsic: str
@@ -66,8 +67,11 @@ class TuningSpec:
 
     def get_lowering_config(self) -> str:
         return (
-            f"#iree_codegen.lowering_config<"
-            + f"tile_sizes = [[{','.join([str(x) for x in self.wg_tiles])}]]"
+            f"#iree_gpu.lowering_config<"
+            + "{ "
+            + f"workgroup = [{', '.join(map(str, self.wg_tiles))}], "
+            + f"reduction = [{', '.join(map(str, self.reduction_tiles))}]"
+            + " }"
             + f">"
         )
 
@@ -145,7 +149,7 @@ def generate_mlir(config: AttentionConfig, tuning: Optional[TuningSpec] = None):
 
 
 def get_attention_flags() -> list[str]:
-    return []
+    return ["--iree-codegen-gpu-native-math-precision"]
 
 
 def compile_attention_config(
@@ -157,7 +161,7 @@ def compile_attention_config(
 
     # TODO: Use different tuning specs for different configs. This is just a
     # general tuning config that worked well for sdxl shapes.
-    spec = TuningSpec([1, 128, 0, 0, 32], 4, 1, "MFMA_F32_32x32x8_F16", 2, True)
+    spec = TuningSpec([1, 128, 0, 0, 0], [0, 0, 0, 0, 32], 4, 1, "MFMA_F32_32x32x8_F16", 2, True)
     # Generate mlir content
     mlir_content = generate_mlir(config, spec)
 
@@ -196,3 +200,9 @@ def compile_attention_config(
         return mlir_file, None
 
     return mlir_file, vmfb_file
+
+# Dummy test generation
+if __name__ == "__main__":
+    config = AttentionConfig(20, 4096, 64, 64, 4096, "f16")
+    spec = TuningSpec([1, 128, 0, 0, 0], [0, 0, 0, 0, 32], 4, 1, "MFMA_F32_32x32x8_F16", 2, True)
+    print(generate_mlir(config, spec))
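
For reference, a minimal sketch (using the dummy spec from the __main__ block above) of the attribute string the new get_lowering_config() builds; the expected output is inferred from the string concatenation in this diff:

spec = TuningSpec([1, 128, 0, 0, 0], [0, 0, 0, 0, 32], 4, 1, "MFMA_F32_32x32x8_F16", 2, True)
print(spec.get_lowering_config())
# Expected: #iree_gpu.lowering_config<{ workgroup = [1, 128, 0, 0, 0], reduction = [0, 0, 0, 0, 32] }>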