Skip to content

Commit a56433a

Browse files
Added changes for correct functional references
Signed-off-by: Keshav Vinayak Jha <[email protected]>
1 parent 52f1fbc commit a56433a

File tree

1 file changed

+9
-26
lines changed

1 file changed

+9
-26
lines changed

python/torch_mlir/extras/fx_importer.py

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1954,10 +1954,7 @@ def _import_hop_flex_attention(
19541954
score_mod_name = score_mod_arg.target
19551955
score_mod_module = getattr(root_module, score_mod_name, None)
19561956
if score_mod_module is not None:
1957-
# The function was imported by _import_all_child_modules with this naming convention
1958-
score_mod_func_name = (
1959-
f"main_{score_mod_name}_{id(score_mod_module)}"
1960-
)
1957+
score_mod_func_name = score_mod_name
19611958
score_mod_ref = FlatSymbolRefAttr.get(score_mod_func_name)
19621959

19631960
# Handle block_mask: extract mask_mod function and tensor components
@@ -1985,7 +1982,7 @@ def _import_hop_flex_attention(
19851982
# Check if it's a GraphModule (mask_mod) or a tensor
19861983
if isinstance(obj, GraphModule):
19871984
# This is the mask_mod function
1988-
mask_mod_func_name = f"main_{component.target}_{id(obj)}"
1985+
mask_mod_func_name = component.target
19891986
mask_mod_ref = FlatSymbolRefAttr.get(mask_mod_func_name)
19901987
else:
19911988
# It's a tensor (block indices)
@@ -2042,18 +2039,6 @@ def _import_hop_flex_attention(
20422039
# Single output
20432040
result_types = [self._cc.node_val_to_type(node)]
20442041

2045-
# Build operands list for aten.flex_attention
2046-
# We'll pass tensors as operands and functions as attributes
2047-
operands = [query, key, value]
2048-
2049-
# Add block_mask tensors if present
2050-
operands.extend(block_mask_tensors)
2051-
2052-
# Add scale and enable_gqa
2053-
operands.append(scale)
2054-
operands.append(enable_gqa_value)
2055-
2056-
# Create aten.flex_attention op directly.
20572042
with loc:
20582043
return_lse = _make_constant_op(
20592044
"torch.constant.bool",
@@ -2119,15 +2104,13 @@ def _import_hop_flex_attention(
21192104
]
21202105

21212106
# Build attributes with function references
2122-
attributes = {}
2123-
if score_mod_ref is not None:
2124-
attributes["score_mod_fn"] = score_mod_ref
2125-
if mask_mod_ref is not None:
2126-
attributes["mask_mod_fn"] = mask_mod_ref
2127-
if kv_block_size is not None:
2128-
attributes["kv_block_size"] = self._cc.integer_attr(kv_block_size, 64)
2129-
if q_block_size is not None:
2130-
attributes["q_block_size"] = self._cc.integer_attr(q_block_size, 64)
2107+
attributes = {
2108+
"score_mod_fn": score_mod_ref,
2109+
"mask_mod_fn": mask_mod_ref,
2110+
"kv_block_size": self._cc.integer_attr(kv_block_size, 64),
2111+
"q_block_size": self._cc.integer_attr(q_block_size, 64),
2112+
}
2113+
attributes = {k: v for k, v in attributes.items() if v is not None}
21312114

21322115
operation = Operation.create(
21332116
"torch.aten.flex_attention",

0 commit comments

Comments (0)