@@ -1954,10 +1954,7 @@ def _import_hop_flex_attention(
19541954 score_mod_name = score_mod_arg .target
19551955 score_mod_module = getattr (root_module , score_mod_name , None )
19561956 if score_mod_module is not None :
1957- # The function was imported by _import_all_child_modules with this naming convention
1958- score_mod_func_name = (
1959- f"main_{ score_mod_name } _{ id (score_mod_module )} "
1960- )
1957+ score_mod_func_name = score_mod_name
19611958 score_mod_ref = FlatSymbolRefAttr .get (score_mod_func_name )
19621959
19631960 # Handle block_mask: extract mask_mod function and tensor components
@@ -1985,7 +1982,7 @@ def _import_hop_flex_attention(
19851982 # Check if it's a GraphModule (mask_mod) or a tensor
19861983 if isinstance (obj , GraphModule ):
19871984 # This is the mask_mod function
1988- mask_mod_func_name = f"main_ { component .target } _ { id ( obj ) } "
1985+ mask_mod_func_name = component .target
19891986 mask_mod_ref = FlatSymbolRefAttr .get (mask_mod_func_name )
19901987 else :
19911988 # It's a tensor (block indices)
@@ -2042,18 +2039,6 @@ def _import_hop_flex_attention(
20422039 # Single output
20432040 result_types = [self ._cc .node_val_to_type (node )]
20442041
2045- # Build operands list for aten.flex_attention
2046- # We'll pass tensors as operands and functions as attributes
2047- operands = [query , key , value ]
2048-
2049- # Add block_mask tensors if present
2050- operands .extend (block_mask_tensors )
2051-
2052- # Add scale and enable_gqa
2053- operands .append (scale )
2054- operands .append (enable_gqa_value )
2055-
2056- # Create aten.flex_attention op directly.
20572042 with loc :
20582043 return_lse = _make_constant_op (
20592044 "torch.constant.bool" ,
@@ -2119,15 +2104,13 @@ def _import_hop_flex_attention(
21192104 ]
21202105
21212106 # Build attributes with function references
2122- attributes = {}
2123- if score_mod_ref is not None :
2124- attributes ["score_mod_fn" ] = score_mod_ref
2125- if mask_mod_ref is not None :
2126- attributes ["mask_mod_fn" ] = mask_mod_ref
2127- if kv_block_size is not None :
2128- attributes ["kv_block_size" ] = self ._cc .integer_attr (kv_block_size , 64 )
2129- if q_block_size is not None :
2130- attributes ["q_block_size" ] = self ._cc .integer_attr (q_block_size , 64 )
2107+ attributes = {
2108+ "score_mod_fn" : score_mod_ref ,
2109+ "mask_mod_fn" : mask_mod_ref ,
2110+ "kv_block_size" : self ._cc .integer_attr (kv_block_size , 64 ),
2111+ "q_block_size" : self ._cc .integer_attr (q_block_size , 64 ),
2112+ }
2113+ attributes = {k : v for k , v in attributes .items () if v is not None }
21312114
21322115 operation = Operation .create (
21332116 "torch.aten.flex_attention" ,
0 commit comments