diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index c336660335a6..48814e5cb9b6 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -15193,6 +15193,43 @@ def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_at
   }];
 }
 
+def Torch_Aten_ScaledDotProductFlashAttentionOp : Torch_Op<"aten._scaled_dot_product_flash_attention", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::_scaled_dot_product_flash_attention : (Tensor, Tensor, Tensor, float, bool, bool, float?) -> (Tensor, Tensor, Tensor, Tensor, int, int, Tensor, Tensor, Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$query,
+    AnyTorchTensorType:$key,
+    AnyTorchTensorType:$value,
+    Torch_FloatType:$dropout_p,
+    Torch_BoolType:$is_causal,
+    Torch_BoolType:$return_debug_mask,
+    AnyTorchOptionalFloatType:$scale
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$output,
+    AnyTorchOptionalTensorType:$logsumexp,
+    AnyTorchOptionalTensorType:$cum_seq_q,
+    AnyTorchOptionalTensorType:$cum_seq_k,
+    Torch_IntType:$max_q,
+    Torch_IntType:$max_k,
+    AnyTorchOptionalTensorType:$rng_state,
+    AnyTorchOptionalTensorType:$unused,
+    AnyTorchOptionalTensorType:$debug_attn_mask
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult Aten_ScaledDotProductFlashAttentionOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 7, 9);
+    }
+    void Aten_ScaledDotProductFlashAttentionOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 7, 9);
+    }
+  }];
+}
+
 def Torch_AtenGridSamplerOp : Torch_Op<"aten.grid_sampler", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index df75d9427480..7bcd56a76c22 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -1087,9 +1087,14 @@ def emit_with_mutating_variants(key, **kwargs):
         "aten::upsample_bilinear2d : (Tensor, int[], bool, float?, float?) -> (Tensor)"
     )
     emit("aten::upsample_bilinear2d.vec : (Tensor, int[]?, bool, float[]?) -> (Tensor)")
+
+    # Attention ops.
     emit(
         "aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?, bool) -> (Tensor)"
     )
+    emit(
+        "aten::_scaled_dot_product_flash_attention : (Tensor, Tensor, Tensor, float, bool, bool, float?) -> (Tensor, Tensor, Tensor, Tensor, int, int, Tensor, Tensor, Tensor)"
+    )
     emit("aten::grid_sampler : (Tensor, Tensor, int, int, bool) -> (Tensor)")
     emit(
         "aten::_trilinear : (Tensor, Tensor, Tensor, int[], int[], int[], int[], int) -> (Tensor)"
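
For context, `parseDefaultTorchOp(parser, result, 7, 9)` and `printDefaultTorchOp(printer, *this, 7, 9)` wire the op to the default Torch-dialect assembly format for 7 operands and 9 results, so the generated op should round-trip roughly like the sketch below. This is a hand-written illustration, not output from an actual import: the SSA value names merely mirror the argument and result names declared above, and the type lists are left as placeholders.

    // Schematic only; operand and result types elided.
    %output, %logsumexp, %cum_seq_q, %cum_seq_k, %max_q, %max_k, %rng_state, %unused, %debug_attn_mask =
        torch.aten._scaled_dot_product_flash_attention %query, %key, %value, %dropout_p, %is_causal, %return_debug_mask, %scale
        : <7 operand types> -> <9 result types>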