58 changes: 58 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -15194,6 +15194,64 @@ def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_at
}];
}

def Torch_Aten_ScaledDotProductFlashAttentionOp : Torch_Op<"aten._scaled_dot_product_flash_attention", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_scaled_dot_product_flash_attention : (Tensor, Tensor, Tensor, float, bool, bool, float?) -> (Tensor, Tensor, Tensor, Tensor, int, int, Tensor, Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$query,
AnyTorchTensorType:$key,
AnyTorchTensorType:$value,
Torch_FloatType:$dropout_p,
Torch_BoolType:$is_causal,
Torch_BoolType:$return_debug_mask,
AnyTorchOptionalFloatType:$scale
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_ScaledDotProductFlashAttentionOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void Aten_ScaledDotProductFlashAttentionOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}

def Torch_Aten_ScaledDotProductFlashAttentionForCpuOp : Torch_Op<"aten._scaled_dot_product_flash_attention_for_cpu", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_scaled_dot_product_flash_attention_for_cpu : (Tensor, Tensor, Tensor, float, bool, bool, float?) -> (Tensor, Tensor, Tensor, Tensor, int, int, Tensor, Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$query,
AnyTorchTensorType:$key,
AnyTorchTensorType:$value,
Torch_FloatType:$dropout_p,
Torch_BoolType:$is_causal,
Torch_BoolType:$return_debug_mask,
AnyTorchOptionalFloatType:$scale
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_ScaledDotProductFlashAttentionForCpuOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void Aten_ScaledDotProductFlashAttentionForCpuOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}

def Torch_AtenGridSamplerOp : Torch_Op<"aten.grid_sampler", [
AllowsTypeRefinement,
HasValueSemantics,
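For reference, the ATen operator that the two new ODS records model can be exercised directly from Python. A minimal sketch, assuming a PyTorch build that accepts this positional layout for the internal `_for_cpu` variant (its exact signature and result arity vary across PyTorch releases); shapes follow the decomposition test at the bottom of this diff:

import torch

# Shapes match the first decomposition test below.
q = torch.randn(32, 8, 128, 64)
k = torch.randn(32, 8, 128, 64)
v = torch.randn(32, 8, 128, 64)

# The internal flash-attention ops return a tuple; element 0 is the
# attention output, which is the only result the Torch dialect op models here.
out = torch.ops.aten._scaled_dot_product_flash_attention_for_cpu(
    q, k, v, 0.0, False)[0]
print(out.shape)  # torch.Size([32, 8, 128, 64])
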
29 changes: 29 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -4533,6 +4533,31 @@ class DecomposeAtenEluOp : public OpRewritePattern<AtenEluOp> {
};
} // namespace

// Decompose aten._scaled_dot_product_flash_attention and
// aten._scaled_dot_product_flash_attention_for_cpu into
// aten.scaled_dot_product_attention. The flash variants carry no attention
// mask, so the mask operand becomes `none`; enable_gqa is passed as true so
// that grouped-query layouts (fewer key/value heads than query heads) remain
// accepted after the rewrite.
namespace {
template <typename OpTy>
class DecomposeScaledDotProductFlashAttentionOp
: public OpRewritePattern<OpTy> {
public:
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value constNone = ConstantNoneOp::create(rewriter, loc);
Value constTrue =
ConstantBoolOp::create(rewriter, loc, rewriter.getBoolAttr(true));

rewriter.replaceOpWithNewOp<AtenScaledDotProductAttentionOp>(
op, op.getType(), op.getQuery(), op.getKey(), op.getValue(),
/*attn_mask=*/constNone, op.getDropoutP(), op.getIsCausal(),
op.getScale(),
/*enable_gqa=*/constTrue);
return success();
}
};
} // namespace

// Selu = scale * (max(0,x) + min(0,alpha * (exp(x) − 1)))
namespace {
class DecomposeAtenSeluOp : public OpRewritePattern<AtenSeluOp> {
@@ -13324,6 +13349,10 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenNormalFunctionalOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeScaledDotProductFlashAttentionOp<
Aten_ScaledDotProductFlashAttentionOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeScaledDotProductFlashAttentionOp<
Aten_ScaledDotProductFlashAttentionForCpuOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFakeQuantizePerTensorAffineOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSeluOp>(patterns);
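The net effect of the pattern, restated at the PyTorch level rather than in torch-mlir C++ (a sketch, assuming PyTorch >= 2.5 for the `enable_gqa` keyword): the flash-attention variants collapse to generic SDPA with no attention mask, and the dropout, causality, and scale operands pass straight through.

import torch
import torch.nn.functional as F

def decomposed_flash_attention(q, k, v, dropout_p=0.0, is_causal=False, scale=None):
    # Mirrors the rewrite above: attn_mask is hard-wired to None and
    # enable_gqa to True; the remaining operands are forwarded unchanged.
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=dropout_p,
        is_causal=is_causal, scale=scale, enable_gqa=True)
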
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -1088,9 +1088,17 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::upsample_bilinear2d : (Tensor, int[], bool, float?, float?) -> (Tensor)"
)
emit("aten::upsample_bilinear2d.vec : (Tensor, int[]?, bool, float[]?) -> (Tensor)")

# Attention ops.
emit(
"aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?, bool) -> (Tensor)"
)
emit(
"aten::_scaled_dot_product_flash_attention : (Tensor, Tensor, Tensor, float, bool, bool, float?) -> (Tensor, Tensor, Tensor, Tensor, int, int, Tensor, Tensor, Tensor)"
)
emit(
"aten::_scaled_dot_product_flash_attention_for_cpu : (Tensor, Tensor, Tensor, float, bool, bool, float?) -> (Tensor, Tensor, Tensor, Tensor, int, int, Tensor, Tensor, Tensor)"
)
emit("aten::grid_sampler : (Tensor, Tensor, int, int, bool) -> (Tensor)")
emit(
"aten::_trilinear : (Tensor, Tensor, Tensor, int[], int[], int[], int[], int) -> (Tensor)"
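The registry key on each `emit()` line determines the generated ODS record name. A rough sketch of the mapping, assuming the usual snake-case-to-CamelCase convention (`ods_record_name` is illustrative, not the generator's actual helper):

def ods_record_name(key: str) -> str:
    # e.g. "aten::_scaled_dot_product_flash_attention : (...) -> (...)"
    qualified = key.split(" : ")[0]
    ns, name = qualified.split("::")
    underscore = "_" if name.startswith("_") else ""
    camel = "".join(part.title() for part in name.lstrip("_").split("_"))
    return f"Torch_{ns.capitalize()}{underscore}{camel}Op"

assert ods_record_name(
    "aten::_scaled_dot_product_flash_attention : (Tensor, ...) -> (...)"
) == "Torch_Aten_ScaledDotProductFlashAttentionOp"

After editing this registry, GeneratedTorchOps.td is regenerated rather than edited by hand; torch-mlir ships an `update_torch_ods.sh` helper under `build_tools/` for this.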
38 changes: 38 additions & 0 deletions test/Dialect/Torch/decompose-complex-ops.mlir
@@ -950,3 +950,41 @@ func.func @channel_shuffle(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtens
%0 = torch.aten.channel_shuffle %arg0, %int4 : !torch.vtensor<[1,8,4,4],f32>, !torch.int -> !torch.vtensor<[1,8,4,4],f32>
return %0 : !torch.vtensor<[1,8,4,4],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten._scaled_dot_product_flash_attention
func.func @torch.aten._scaled_dot_product_flash_attention(%arg0: !torch.vtensor<[32,8,128,64],f32>, %arg1: !torch.vtensor<[32,8,128,64],f32>, %arg2: !torch.vtensor<[32,8,128,64],f32>) -> !torch.vtensor<[32,8,128,64],f32> {
%float0 = torch.constant.float 0.000000e+00
%false = torch.constant.bool false
%scale = torch.constant.float 1.000000e+00
%output = torch.aten._scaled_dot_product_flash_attention %arg0, %arg1, %arg2, %float0, %false, %false, %scale : !torch.vtensor<[32,8,128,64],f32>, !torch.vtensor<[32,8,128,64],f32>, !torch.vtensor<[32,8,128,64],f32>, !torch.float, !torch.bool, !torch.bool, !torch.float -> !torch.vtensor<[32,8,128,64],f32>
return %output : !torch.vtensor<[32,8,128,64],f32>
// CHECK-SAME: %[[QUERY:.*]]: !torch.vtensor<[32,8,128,64],f32>, %[[KEY:.*]]: !torch.vtensor<[32,8,128,64],f32>, %[[VALUE:.*]]: !torch.vtensor<[32,8,128,64],f32>) -> !torch.vtensor<[32,8,128,64],f32> {
// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
// CHECK-DAG: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
// CHECK-DAG: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[RESULT:.*]] = torch.aten.scaled_dot_product_attention %[[QUERY]], %[[KEY]], %[[VALUE]], %[[NONE]], %[[FLOAT0]], %[[FALSE]], %[[SCALE]], %[[TRUE]]
// CHECK: return %[[RESULT]] : !torch.vtensor<[32,8,128,64],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten._scaled_dot_product_flash_attention_for_cpu
func.func @torch.aten._scaled_dot_product_flash_attention_for_cpu(%arg0: !torch.vtensor<[4,16,64,32],f16>, %arg1: !torch.vtensor<[4,16,64,32],f16>, %arg2: !torch.vtensor<[4,16,64,32],f16>) -> !torch.vtensor<[4,16,64,32],f16> {
%float0 = torch.constant.float 0.000000e+00
%false = torch.constant.bool false
%scale = torch.constant.float 1.000000e+00
%output = torch.aten._scaled_dot_product_flash_attention_for_cpu %arg0, %arg1, %arg2, %float0, %false, %false, %scale : !torch.vtensor<[4,16,64,32],f16>, !torch.vtensor<[4,16,64,32],f16>, !torch.vtensor<[4,16,64,32],f16>, !torch.float, !torch.bool, !torch.bool, !torch.float -> !torch.vtensor<[4,16,64,32],f16>
return %output : !torch.vtensor<[4,16,64,32],f16>
// CHECK-SAME: %[[QUERY:.*]]: !torch.vtensor<[4,16,64,32],f16>, %[[KEY:.*]]: !torch.vtensor<[4,16,64,32],f16>, %[[VALUE:.*]]: !torch.vtensor<[4,16,64,32],f16>) -> !torch.vtensor<[4,16,64,32],f16> {
// CHECK-DAG: %[[NONE:.*]] = torch.constant.none
// CHECK-DAG: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
// CHECK-DAG: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
// CHECK-DAG: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[RESULT:.*]] = torch.aten.scaled_dot_product_attention %[[QUERY]], %[[KEY]], %[[VALUE]], %[[NONE]], %[[FLOAT0]], %[[FALSE]], %[[SCALE]], %[[TRUE]]
// CHECK: return %[[RESULT]] : !torch.vtensor<[4,16,64,32],f16>
}