diff --git a/attentionbench/attention_utils.py b/attentionbench/attention_utils.py index 7628061..3d9f07e 100644 --- a/attentionbench/attention_utils.py +++ b/attentionbench/attention_utils.py @@ -93,7 +93,7 @@ def get_translation_info(self) -> str: llvm_func_attrs += [f'"denormal-fp-math-f32" = "preserve-sign"'] return ( f"#iree_codegen.translation_info<" - + f"LLVMGPUVectorDistribute" + + f"pipeline = LLVMGPUVectorDistribute" + f" workgroup_size = [{self.N_warp * self.M_warp * 64}]" + f" subgroup_size = 64" + f" ,{{mma_schedule = {self.get_mma_schedule()}"