diff --git a/attentionbench/attention_utils.py b/attentionbench/attention_utils.py
index 7628061..40b3f76 100644
--- a/attentionbench/attention_utils.py
+++ b/attentionbench/attention_utils.py
@@ -76,14 +76,15 @@ def get_lowering_config(self) -> str:
             + f">"
         )
 
-    def get_mma_schedule(self) -> str:
-        return (
-            f"#iree_gpu.mma_schedule<"
-            + f"intrinsic = #iree_gpu.mma_layout<{self.intrinsic}>"
-            + f", subgroup_m_count = {self.M_warp}"
-            + f", subgroup_n_count = {self.N_warp}"
-            + f">"
-        )
+    def get_lowering_config_for_mmt(self, extra_args) -> str:
+        base_str = (f"#iree_gpu.lowering_config<{{"
+            + f"mma_kind = #iree_gpu.mma_layout<{self.intrinsic}>"
+            + f", subgroup_m_count = {self.M_warp}"
+            + f", subgroup_n_count = {self.N_warp}")
+        for arg in extra_args:
+            base_str += f", {arg}"
+        base_str += "}>"
+        return base_str
 
     def get_translation_info(self) -> str:
         llvm_func_attrs = []
@@ -93,11 +94,10 @@ def get_translation_info(self) -> str:
             llvm_func_attrs += [f'"denormal-fp-math-f32" = "preserve-sign"']
         return (
             f"#iree_codegen.translation_info<"
-            + f"LLVMGPUVectorDistribute"
+            + f"pipeline = LLVMGPUVectorDistribute"
            + f" workgroup_size = [{self.N_warp * self.M_warp * 64}]"
             + f" subgroup_size = 64"
-            + f" ,{{mma_schedule = {self.get_mma_schedule()}"
-            + f" , llvm_func_attrs = {{ {','.join(llvm_func_attrs)} }}"
+            + f" , {{llvm_func_attrs = {{ {','.join(llvm_func_attrs)} }}"
             + f"}}"
             + f">"
         )
@@ -139,8 +139,8 @@ def generate_mlir(config: AttentionConfig, tuning: Optional[TuningSpec] = None):
   %O = iree_linalg_ext.attention {{ indexing_maps = [#Q, #K, #V, #S, #O]
                 ,decomposition_config = {{
-                  qk_attrs = {{attention_qk_matmul, lowering_config = #iree_gpu.lowering_config<{{promote_operands = [0, 1]}}>}},
-                  pv_attrs = {{attention_pv_matmul, lowering_config = #iree_gpu.lowering_config<{{promote_operands = [1]}}>}}
+                  qk_attrs = {{attention_qk_matmul, lowering_config = {tuning.get_lowering_config_for_mmt(["promote_operands = [0, 1]"])}}},
+                  pv_attrs = {{attention_pv_matmul, lowering_config = {tuning.get_lowering_config_for_mmt(["promote_operands = [1]"])}}}
                 }}
                 {",compilation_info = #tuning" if tuning and config.dtype == "f16" else ""}
   }}
 
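For reference, here is a minimal standalone sketch of the attribute string the new `get_lowering_config_for_mmt` helper builds. The `TuningSpec` field values below (`intrinsic`, `M_warp`, `N_warp`) are illustrative assumptions, not values taken from this patch:

```python
# Standalone sketch: mirrors the string-building logic of
# TuningSpec.get_lowering_config_for_mmt from this diff.
class TuningSpecSketch:
    # Example values only; the real fields live on the TuningSpec
    # class, which is not shown in this diff.
    intrinsic = "MFMA_F32_32x32x8_F16"
    M_warp = 2
    N_warp = 2

    def get_lowering_config_for_mmt(self, extra_args) -> str:
        base_str = (f"#iree_gpu.lowering_config<{{"
            + f"mma_kind = #iree_gpu.mma_layout<{self.intrinsic}>"
            + f", subgroup_m_count = {self.M_warp}"
            + f", subgroup_n_count = {self.N_warp}")
        for arg in extra_args:
            base_str += f", {arg}"  # e.g. "promote_operands = [0, 1]"
        base_str += "}>"
        return base_str

print(TuningSpecSketch().get_lowering_config_for_mmt(["promote_operands = [0, 1]"]))
# #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>,
#   subgroup_m_count = 2, subgroup_n_count = 2, promote_operands = [0, 1]}>
```

This replaces the old `#iree_gpu.mma_schedule<...>` attribute built by the removed `get_mma_schedule`: the intrinsic now travels as `mma_kind` inside each matmul's `lowering_config`, which is why the `mma_schedule` entry is dropped from `translation_info`.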