Commit 1e709a4 (parent: 5e024f6)

Moved op definition to torchops

Signed-off-by: Keshav Vinayak Jha <[email protected]>

3 files changed, 56 insertions(+), 54 deletions(-)


.gitignore (1 addition, 0 deletions)

@@ -6,6 +6,7 @@
 *.code-workspace
 .ipynb_checkpoints
 *.venv/
+venv/
 mlir_venv/
 externals/pytorch/
 libtorch*

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td (0 additions, 54 deletions)

@@ -16194,60 +16194,6 @@ def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [
   let hasFolder = 1;
 }
 
-def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [
-  AllowsTypeRefinement,
-  HasValueSemantics,
-  ReadOnly
-]> {
-  let summary = "Generated op for `aten::flex_attention`";
-  let description = [{
-    FlexAttention operation with flexible block-sparse attention patterns.
-
-    Args:
-      query: Query tensor [B, H, M, K]
-      key: Key tensor [B, H, N, K]
-      value: Value tensor [B, H, N, Ev]
-      scale: Optional float for scaling attention scores (None means 1/sqrt(head_dim))
-      return_lse: Bool to return log-sum-exp values
-
-    Attributes:
-      score_mod_fn: Optional function symbol reference for score modification
-      mask_mod_fn: Optional function symbol reference for mask modification
-
-    # TODO: kernel_options: Dict attributes for performance tuning (block_size, num_warps, etc.)
-
-    Returns:
-      output: Result tensor [B, H, M, Ev]
-      logsumexp: Optional log-sum-exp tensor [B, H, M] (if return_lse=True)
-  }];
-
-  let arguments = (ins
-    AnyTorchTensorType:$query,
-    AnyTorchTensorType:$key,
-    AnyTorchTensorType:$value,
-    AnyTorchOptionalFloatType:$scale,
-    Torch_BoolType:$enable_gqa,
-    Torch_BoolType:$return_lse,
-    OptionalAttr<FlatSymbolRefAttr>:$score_mod_fn,
-    OptionalAttr<FlatSymbolRefAttr>:$mask_mod_fn
-  );
-
-  let results = (outs
-    AnyTorchTensorType:$output,
-    AnyTorchOptionalTensorType:$logsumexp
-  );
-
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 5, 2);
-    }
-    void AtenFlexAttentionOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 5, 2);
-    }
-  }];
-}
-
 def Torch_AtenFloatStrOp : Torch_Op<"aten.Float.str", [
   AllowsTypeRefinement,
   HasValueSemantics,

include/torch-mlir/Dialect/Torch/IR/TorchOps.td (55 additions, 0 deletions)

@@ -1442,4 +1442,59 @@ def Torch_OnnxVariantRotaryEmbeddingOp: Torch_Op<"onnx.rotary_embedding", [
   let hasCustomAssemblyFormat = 1;
 }
 
+def Torch_AtenFlexAttentionOp : Torch_Op<"aten.flex_attention", [
+  AllowsTypeRefinement,
+  HasValueSemantics,
+  ReadOnly
+]> {
+  let summary = "Generated op for `aten::flex_attention`";
+  let description = [{
+    FlexAttention operation with flexible block-sparse attention patterns.
+
+    Args:
+      query: Query tensor [B, H, M, K]
+      key: Key tensor [B, H, N, K]
+      value: Value tensor [B, H, N, Ev]
+      scale: Optional float for scaling attention scores (None means 1/sqrt(head_dim))
+      enable_gqa: Boolean for grouped query attention support
+      return_lse: Bool to return log-sum-exp values
+
+    Attributes:
+      score_mod_fn: Optional function symbol reference for score modification
+      mask_mod_fn: Optional function symbol reference for mask modification
+
+    TODO: kernel_options: Dict attributes for performance tuning (block_size, num_warps, etc.)
+
+    Returns:
+      output: Result tensor [B, H, M, Ev]
+      logsumexp: Optional log-sum-exp tensor [B, H, M] (if return_lse=True)
+  }];
+
+  let arguments = (ins
+    AnyTorchTensorType:$query,
+    AnyTorchTensorType:$key,
+    AnyTorchTensorType:$value,
+    AnyTorchOptionalFloatType:$scale,
+    Torch_BoolType:$enable_gqa,
+    Torch_BoolType:$return_lse,
+    OptionalAttr<FlatSymbolRefAttr>:$score_mod_fn,
+    OptionalAttr<FlatSymbolRefAttr>:$mask_mod_fn
+  );
+
+  let results = (outs
+    AnyTorchTensorType:$output,
+    AnyTorchOptionalTensorType:$logsumexp
+  );
+
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenFlexAttentionOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 2);
+    }
+    void AtenFlexAttentionOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 2);
+    }
+  }];
+}
+
 #endif // TORCH_OPS
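
For context, the sketch below shows roughly how the moved op could appear in torch-dialect IR. This is a hypothetical example, not output from this commit: it assumes the default operand-then-types syntax that parseDefaultTorchOp/printDefaultTorchOp implement (6 operands, 2 results), it assumes optional attributes print in the usual attribute-dictionary position, and the symbol @causal_score_mod is a made-up stand-in for whatever score-modification function a frontend would attach.

// Hypothetical torch-dialect IR; shapes chosen for illustration
// (B=1, H=8, M=N=128, K=Ev=64). scale is passed as None, so the op
// would use 1/sqrt(head_dim); return_lse=true yields the [B,H,M] lse.
%none = torch.constant.none
%false = torch.constant.bool false
%true = torch.constant.bool true
%output, %lse = torch.aten.flex_attention %query, %key, %value, %none, %false, %true
    {score_mod_fn = @causal_score_mod}
    : !torch.vtensor<[1,8,128,64],f32>, !torch.vtensor<[1,8,128,64],f32>,
      !torch.vtensor<[1,8,128,64],f32>, !torch.none, !torch.bool, !torch.bool
      -> !torch.vtensor<[1,8,128,64],f32>, !torch.vtensor<[1,8,128],f32>

Note also that the two copies are not byte-identical: the version added to TorchOps.td documents enable_gqa in the description and passes 6 (the actual operand count) to parseDefaultTorchOp/printDefaultTorchOp, whereas the deleted copy in GeneratedTorchOps.td still said 5, so the move also corrects that mismatch with the six declared operands.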
