Commits
39 commits
c8c711c
Modified fx_importer to support hop_while_loop
keshavvinayak01 Oct 22, 2025
b250583
Addressed Comments | Simplified unique child_func_name creation
keshavvinayak01 Oct 23, 2025
db1e7e9
Addressed comments
keshavvinayak01 Oct 24, 2025
d9646c6
Formatting
keshavvinayak01 Oct 24, 2025
cc03291
Added children module imports to import_frozen_program flow
keshavvinayak01 Oct 24, 2025
6a70e1c
Formatting and reordered CHECKs
keshavvinayak01 Oct 24, 2025
85e3acd
Changes done to TorchToScf:
keshavvinayak01 Oct 24, 2025
e1ff87d
Added Control flow test
keshavvinayak01 Oct 27, 2025
558c7db
Cannot FX trace HOP
keshavvinayak01 Oct 28, 2025
39d5b24
Added flex_attention hop function
keshavvinayak01 Oct 28, 2025
dfdca75
Formatting
keshavvinayak01 Oct 28, 2025
6178d07
Fixed merge newline removals
keshavvinayak01 Oct 28, 2025
52f1fbc
Added AtenFluxAttentionOp
keshavvinayak01 Oct 29, 2025
a56433a
Added changes for correct functional references
keshavvinayak01 Oct 30, 2025
b0e8585
QOL changes:
keshavvinayak01 Nov 4, 2025
c34efab
Merge branch 'main' into keshavvinayak01/torch-aten-flex_attention
keshavvinayak01 Nov 4, 2025
4470978
Update fx_importer.py to remove deprecated note
keshavvinayak01 Nov 4, 2025
719fe5a
Clarify enable_gqa support in fx_importer.py
keshavvinayak01 Nov 4, 2025
5e024f6
Fix formatting in GeneratedTorchOps.td
keshavvinayak01 Nov 4, 2025
c78d699
return_lse is part of the kernel options
keshavvinayak01 Nov 6, 2025
da23ec9
Moved op definition to TorchOps.td
keshavvinayak01 Nov 7, 2025
af59413
Formatting TorchOps
keshavvinayak01 Nov 7, 2025
0103163
Added lit-test; Docs for FlexAttention
keshavvinayak01 Nov 7, 2025
48f12bc
Formatting
keshavvinayak01 Nov 7, 2025
ec3e5f8
Modified arg extraction
keshavvinayak01 Nov 10, 2025
fa5aba2
Removed enable_gqa from flex_attention; HOP does not accept that argu…
keshavvinayak01 Nov 12, 2025
2b0637c
Typo
keshavvinayak01 Nov 12, 2025
e7da0a7
Simplified arg extract logic
keshavvinayak01 Nov 13, 2025
53dd19a
return_lse should be booltype not i1
keshavvinayak01 Nov 13, 2025
de91ca2
Added basic_test for flex_attention
keshavvinayak01 Nov 14, 2025
47803e3
Formatting and allowed unused unpacked vals
keshavvinayak01 Nov 16, 2025
207621c
Added max_scores; changes to match pytorch naming conventions; Added …
keshavvinayak01 Nov 19, 2025
acc3ade
Corrected lit test
keshavvinayak01 Nov 19, 2025
16fc70c
Renamed aten.flex_attention -> hop_flex_attention; Added more lit tests
keshavvinayak01 Nov 20, 2025
9334c1a
Using direct calls to flex_attention for basic_tests; removed useless…
keshavvinayak01 Nov 24, 2025
cb1fbcd
Formatting
keshavvinayak01 Nov 24, 2025
f145056
Fix typos in comments for basic_test.py
keshavvinayak01 Nov 24, 2025
4ba9d8d
Added Verifier to HigherOrderFlexAttention operation
keshavvinayak01 Nov 25, 2025
9333028
Removed Dynamic head check (backend specific)
keshavvinayak01 Nov 27, 2025
10 changes: 5 additions & 5 deletions include/torch-mlir/Dialect/Torch/IR/TorchOps.td
@@ -1445,19 +1445,19 @@ def Torch_OnnxVariantRotaryEmbeddingOp: Torch_Op<"onnx.rotary_embedding", [
//===----------------------------------------------------------------------===//
// FlexAttention operation

// NOTE: This op is manually defined because `aten::flex_attention` exists in
// NOTE: This op is manually defined because flex_attention exists in
// PyTorch's Python API (torch.nn.attention.flex_attention) but is not yet
// registered in PyTorch's JIT operator registry. The update_torch_ods.sh script
// validates against the JIT registry, so it cannot auto-generate this op.
// Once PyTorch adds flex_attention to the JIT registry, this can be moved to
// the auto-generated section.
//===----------------------------------------------------------------------===//
def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [
def Torch_HigherOrderFlexAttentionOp : Torch_Op<"hop_flex_attention", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::flex_attention`";
let summary = "Computes the flex_attention operation (1-1 with torch._higher_order_ops.flex_attention)";
let description = [{
FlexAttention operation with flexible block-sparse attention patterns.

@@ -1499,10 +1499,10 @@ def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [

let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) {
ParseResult HigherOrderFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 3);
}
void AtenFlexAttentionOp::print(OpAsmPrinter &printer) {
void HigherOrderFlexAttentionOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 3);
}
}];
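For reference, a minimal PyTorch-side sketch of the higher-order op this definition models (torch._higher_order_ops.flex_attention, surfaced via torch.nn.attention.flex_attention). The score_mod and mask_mod bodies below are illustrative assumptions chosen to mirror the sdpa_score0/sdpa_mask0 helpers in the lit tests, not code from this PR:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def score_mod(score, b, h, q_idx, kv_idx):
    # Elementwise rewrite of each attention score (tanh, as in sdpa_score0).
    return torch.tanh(score)

def mask_mod(b, h, q_idx, kv_idx):
    # Causal mask: a query position attends only to earlier or equal kv positions.
    return q_idx >= kv_idx

q = torch.randn(4, 8, 1024, 64)
k = torch.randn(4, 8, 1024, 64)
v = torch.randn(4, 8, 1024, 64)

# mask_mod is materialized into a block mask; score_mod is passed directly.
block_mask = create_block_mask(mask_mod, B=4, H=8, Q_LEN=1024, KV_LEN=1024)
out = flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask, scale=1.0)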
4 changes: 2 additions & 2 deletions python/torch_mlir/extras/fx_importer.py
@@ -1918,7 +1918,7 @@ def _import_hop_flex_attention(
- kernel_options: Optional Dict of performance tuning options:
- return_lse: Boolean for whether to return the log-sum-exp tensor

This creates a call to aten.flex_attention with function symbol references for
This creates a call to hop_flex_attention with function symbol references for
score_mod and mask_mod.
"""
# flex_attention HOP args from PyTorch:
@@ -2035,7 +2035,7 @@ def _import_hop_flex_attention(
attributes["mask_mod_fn"] = mask_mod_ref

operation = Operation.create(
"torch.aten.flex_attention",
"torch.hop_flex_attention",
results=result_types,
operands=flat_operands,
attributes=attributes if attributes else None,
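A hedged end-to-end sketch of the import path touched here: export a module that calls flex_attention and import it with torch-mlir's FX importer. The fx.export_and_import entry point and the inline score_mod are assumptions for illustration and may not match exactly how this PR's tests drive the importer:

import torch
from torch.nn.attention.flex_attention import flex_attention
from torch_mlir import fx

class FlexAttentionModule(torch.nn.Module):
    def forward(self, q, k, v):
        def score_mod(score, b, h, q_idx, kv_idx):
            # Captured as its own FX subgraph; the importer emits it as a
            # private func.func and references it via the score_mod_fn attribute.
            return torch.tanh(score)

        return flex_attention(q, k, v, score_mod=score_mod, scale=1.0)

q = torch.randn(4, 8, 1024, 64)
k = torch.randn(4, 8, 1024, 64)
v = torch.randn(4, 8, 1024, 64)
module = fx.export_and_import(FlexAttentionModule(), q, k, v)
print(module)  # expect a torch.hop_flex_attention op with score_mod_fn set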
76 changes: 56 additions & 20 deletions test/Dialect/Torch/ops.mlir
@@ -206,36 +206,72 @@ func.func @torch.aten.fake_quantize_per_tensor_affine.tensor_qparams (%arg0: !to
return %1 : !torch.vtensor<[3,3],f32>
}

// CHECK-LABEL: func.func @torch.aten.flex_attention
func.func @torch.aten.flex_attention (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) {

//===----------------------------------------------------------------------===//
// FlexAttention variant tests
//===----------------------------------------------------------------------===//

func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> {
%5 = torch.aten.tanh %arg0 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
return %5 : !torch.vtensor<[],f32>
}

func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1> {
%0 = torch.aten.ge.Tensor %arg2, %arg3 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1>
return %0 : !torch.vtensor<[],i1>
}

// CHECK-LABEL: func.func @torch.hop_flex_attention
func.func @torch.hop_flex_attention (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) {
%float1.0 = torch.constant.float 1.000000e+00
%false_0 = torch.constant.bool false
// CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: torch.aten.flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]]
// CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]]
// CHECK-SAME: {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0}
// CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool
// CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>
%output, %logsumexp, %maxscore = torch.aten.flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
%output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
}

func.func private @sdpa_score0(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>, %arg4: !torch.vtensor<[],si32>) -> !torch.vtensor<[],f32> {
%int1 = torch.constant.int 1
%0 = torch.aten.sub.Tensor %arg3, %arg4, %int1 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32>, !torch.int -> !torch.vtensor<[],si32>
%float1.000000e-01 = torch.constant.float 1.000000e-01
%1 = torch.aten.mul.Scalar %arg2, %float1.000000e-01 : !torch.vtensor<[],si32>, !torch.float -> !torch.vtensor<[],f32>
%float1.000000e-02 = torch.constant.float 1.000000e-02
%2 = torch.aten.mul.Scalar %0, %float1.000000e-02 : !torch.vtensor<[],si32>, !torch.float -> !torch.vtensor<[],f32>
%int1_0 = torch.constant.int 1
%3 = torch.aten.add.Tensor %arg0, %2, %int1_0 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
%int1_1 = torch.constant.int 1
%4 = torch.aten.add.Tensor %3, %1, %int1_1 : !torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
%5 = torch.aten.tanh %4 : !torch.vtensor<[],f32> -> !torch.vtensor<[],f32>
return %5 : !torch.vtensor<[],f32>
// CHECK-LABEL: func.func @torch.hop_flex_attention_nomask
func.func @torch.hop_flex_attention_nomask (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) {
%float1.0 = torch.constant.float 1.000000e+00
%false_0 = torch.constant.bool false
// CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]]
// CHECK-SAME: {score_mod_fn = @sdpa_score0}
// CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool
// CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>
%output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {score_mod_fn = @sdpa_score0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
}

func.func private @sdpa_mask0(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>, %arg3: !torch.vtensor<[],si32>) -> !torch.vtensor<[],i1> {
%0 = torch.aten.ge.Tensor %arg2, %arg3 : !torch.vtensor<[],si32>, !torch.vtensor<[],si32> -> !torch.vtensor<[],i1>
return %0 : !torch.vtensor<[],i1>
// CHECK-LABEL: func.func @torch.hop_flex_attention_noscore
func.func @torch.hop_flex_attention_noscore (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) {
%float1.0 = torch.constant.float 1.000000e+00
%false_0 = torch.constant.bool false
// CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]]
// CHECK-SAME: {mask_mod_fn = @sdpa_mask0}
// CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool
// CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>
%output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 {mask_mod_fn = @sdpa_mask0} : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
}

// CHECK-LABEL: func.func @torch.hop_flex_attention_noscore_nomask
func.func @torch.hop_flex_attention_noscore_nomask (%arg0: !torch.vtensor<[2,4,8,16],f32>, %arg1: !torch.vtensor<[2,4,8,16],f32>, %arg2: !torch.vtensor<[2,4,8,16],f32>) -> (!torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>) {
%float1.0 = torch.constant.float 1.000000e+00
%false_0 = torch.constant.bool false
// CHECK: %[[FLOAT:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: torch.hop_flex_attention %arg0, %arg1, %arg2, %[[FLOAT]], %[[FALSE]], %[[FALSE]]
// CHECK-SAME: : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool
// CHECK-SAME: -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>
%output, %logsumexp, %maxscore = torch.hop_flex_attention %arg0, %arg1, %arg2, %float1.0, %false_0, %false_0 : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8,16],f32>, !torch.float, !torch.bool, !torch.bool -> !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
return %output, %logsumexp, %maxscore : !torch.vtensor<[2,4,8,16],f32>, !torch.vtensor<[2,4,8],f32>, !torch.vtensor<[2,4,8],f32>
}
2 changes: 1 addition & 1 deletion test/python/fx_importer/basic_test.py
@@ -265,7 +265,7 @@ def body(i, x):
# CHECK: %[[SCALE:.*]] = torch.constant.float 1.000000e+00
# CHECK: %[[RETURN_LSE:.*]] = torch.constant.bool false
# CHECK: %[[RETURN_MAX:.*]] = torch.constant.bool false
# CHECK: %[[OUTPUT:.*]], %[[LOGSUMEXP:.*]], %[[MAX_SCORES:.*]] = torch.aten.flex_attention %arg0, %arg1, %arg2, %[[SCALE]], %[[RETURN_LSE]], %[[RETURN_MAX]] {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0}
# CHECK: %[[OUTPUT:.*]], %[[LOGSUMEXP:.*]], %[[MAX_SCORES:.*]] = torch.hop_flex_attention %arg0, %arg1, %arg2, %[[SCALE]], %[[RETURN_LSE]], %[[RETURN_MAX]] {mask_mod_fn = @sdpa_mask0, score_mod_fn = @sdpa_score0}
# CHECK-SAME: : !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024,64],f32>, !torch.float, !torch.bool, !torch.bool
# CHECK-SAME: -> !torch.vtensor<[4,8,1024,64],f32>, !torch.vtensor<[4,8,1024],f32>, !torch.vtensor<[4,8,1024],f32>
# CHECK: return %[[OUTPUT]]