Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -15193,6 +15193,43 @@ def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_at
}];
}

def Torch_AtenScaledDotProductFlashAttentionOp : Torch_Op<"aten._scaled_dot_product_flash_attention", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_scaled_dot_product_flash_attention(Tensor, Tensor, Tensor, float, bool, bool, float?) -> (Tensor, Tensor, Tensor, Tensor, SymInt, SymInt, Tensor, Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$query,
AnyTorchTensorType:$key,
AnyTorchTensorType:$value,
Torch_FloatType:$dropout_p,
Torch_BoolType:$is_causal,
Torch_BoolType:$return_debug_mask,
AnyTorchOptionalFloatType:$scale
);
let results = (outs
AnyTorchOptionalTensorType:$output,
AnyTorchOptionalTensorType:$logsumexp,
AnyTorchOptionalTensorType:$cum_seq_q,
AnyTorchOptionalTensorType:$cum_seq_k,
Torch_IntType:$max_q,
Torch_IntType:$max_k,
AnyTorchOptionalTensorType:$rng_state,
AnyTorchOptionalTensorType:$unused,
AnyTorchOptionalTensorType:$debug_attn_mask
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenScaledDotProductFlashAttentionOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 9);
}
void AtenScaledDotProductFlashAttentionOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 9);
}
}];
}

def Torch_AtenGridSamplerOp : Torch_Op<"aten.grid_sampler", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"int[]": "AnyTorchListOfTorchIntType",
"int?": "AnyTorchOptionalIntType",
"int[]?": "AnyTorchOptionalListOfTorchIntType",
"SymInt": "Torch_IntType",
"bool": "Torch_BoolType",
"bool[]": "AnyTorchListOfTorchBoolType",
"bool?": "AnyTorchOptionalBoolType",
Expand Down Expand Up @@ -1087,9 +1088,14 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::upsample_bilinear2d : (Tensor, int[], bool, float?, float?) -> (Tensor)"
)
emit("aten::upsample_bilinear2d.vec : (Tensor, int[]?, bool, float[]?) -> (Tensor)")

# Attention ops.
emit(
"aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?, bool) -> (Tensor)"
)
emit(
"aten::_scaled_dot_product_flash_attention(Tensor, Tensor, Tensor, float, bool, bool, float?) -> (Tensor, Tensor, Tensor, Tensor, SymInt, SymInt, Tensor, Tensor, Tensor)"
)
emit("aten::grid_sampler : (Tensor, Tensor, int, int, bool) -> (Tensor)")
emit(
"aten::_trilinear : (Tensor, Tensor, Tensor, int[], int[], int[], int[], int) -> (Tensor)"
Expand Down
Loading