@@ -1922,14 +1922,19 @@ def _import_hop_flex_attention(
     ):
         """Imports the torch._higher_order_ops.flex_attention HOP.
 
-        Args format: (query, key, value, score_mod, block_mask, scale, kernel_options, ...)
-        The score_mod is a submodule/callable that has been imported as a private function.
-        The block_mask is a tuple: (kv_num_blocks, kv_indices, ..., mask_mod)
-
-        This creates a call to aten.flex_attention with function symbol references.
+        Args format: (query, key, value, score_mod, block_mask, scale, enable_gqa, kernel_options, ...)
+        - query, key, value: Attention input tensors
+        - score_mod: Optional submodule/callable for score modification (imported as function)
+        - block_mask: Optional BlockMask tuple containing mask_mod function and runtime tensors
+        - scale: Optional float for attention score scaling
+        - enable_gqa: Boolean for grouped query attention support (TODO: NYI)
+        - kernel_options: Dict of performance tuning options (TODO: NYI)
+
+        This creates a call to aten.flex_attention with function symbol references for
+        score_mod and mask_mod.
         """
         # flex_attention HOP args from PyTorch:
-        # (query, key, value, score_mod, block_mask, scale, kernel_options, return_lse_tuple, ...)
+        # (query, key, value, score_mod, block_mask, scale, enable_gqa, kernel_options, return_lse_tuple, ...)
         if len(node.args) < 6:
             raise ValueError(
                 f"flex_attention expects at least 6 arguments, got {len(node.args)}"
@@ -1938,68 +1943,51 @@ def _import_hop_flex_attention(
         query_arg, key_arg, value_arg, score_mod_arg, block_mask_arg, scale_arg = (
             node.args[:6]
         )
-        kernel_options = node.args[6] if len(node.args) > 6 else {}
+
+        # TODO: Add support for enable_gqa (grouped query attention)
+        # This is a boolean flag that enables GQA optimization
+        enable_gqa = node.args[6] if len(node.args) > 6 else False
+
+        # TODO: Add support for kernel_options (performance tuning parameters)
+        # This is a dict containing options like block sizes, num_warps, etc.
+        kernel_options = node.args[7] if len(node.args) > 7 else {}
 
         # Import Q, K, V tensors
         query = self._import_argument(loc, query_arg, None)
         key = self._import_argument(loc, key_arg, None)
         value = self._import_argument(loc, value_arg, None)
 
-        # Handle score_mod: extract function reference from submodule
         score_mod_ref = None
         if score_mod_arg is not None and isinstance(score_mod_arg, torch_fx.Node):
-            # score_mod is a GraphModule reference from get_attr
+            assert (
+                score_mod_arg.op == "get_attr"
+            ), f"Expected get_attr for score_mod, got {score_mod_arg.op}"
             root_module = node.graph.owning_module
-            if hasattr(score_mod_arg, "target"):
-                score_mod_name = score_mod_arg.target
-                score_mod_module = getattr(root_module, score_mod_name, None)
-                if score_mod_module is not None:
-                    score_mod_func_name = score_mod_name
-                    score_mod_ref = FlatSymbolRefAttr.get(score_mod_func_name)
-
-        # Handle block_mask: extract mask_mod function and tensor components
-        # block_mask tuple format: (kv_num_blocks, kv_indices, q_num_blocks, q_indices,
-        #                           kv_block_size, q_block_size, ..., mask_mod)
-        mask_mod_ref = None
-        block_mask_tensors = []
-        kv_block_size = None
-        q_block_size = None
+            score_mod_module = getattr(root_module, score_mod_arg.target, None)
+            if score_mod_module is not None:
+                score_mod_func_name = self.fx_importer._graph_module_to_func_name[
+                    id(score_mod_module)
+                ]
+                score_mod_ref = FlatSymbolRefAttr.get(score_mod_func_name)
 
+        # Handle block_mask: extract only mask_mod function reference
+        # Note: BlockMask contains runtime tensors (kv_num_blocks, kv_indices, etc.)
+        # that are materialized by evaluating mask_mod(b, h, q_idx, kv_idx).
+        mask_mod_ref = None
         if block_mask_arg is not None and isinstance(block_mask_arg, tuple):
-            # Parse the block_mask tuple structure
-            # First two entries: kv_num_blocks (int), kv_indices (tensor)
-            # Next two: q_num_blocks (tensor), q_indices (tensor)
-            # Then: scalar dimensions and the mask_mod function at the end
             root_module = node.graph.owning_module
-
-            for i, component in enumerate(block_mask_arg):
-                if isinstance(component, torch_fx.Node):
-                    # Check if it's a tensor or a submodule reference
-                    if component.op == "get_attr" and hasattr(
-                        root_module, component.target
-                    ):
-                        obj = getattr(root_module, component.target)
-                        # Check if it's a GraphModule (mask_mod) or a tensor
-                        if isinstance(obj, GraphModule):
-                            # This is the mask_mod function
-                            mask_mod_func_name = component.target
-                            mask_mod_ref = FlatSymbolRefAttr.get(mask_mod_func_name)
-                        else:
-                            # It's a tensor (block indices)
-                            block_mask_tensors.append(
-                                self._import_argument(loc, component, None)
-                            )
-                    else:
-                        # Regular tensor argument
-                        block_mask_tensors.append(
-                            self._import_argument(loc, component, None)
-                        )
-                elif isinstance(component, int):
-                    # Scalar dimensions (KV_BLOCK_SIZE, Q_BLOCK_SIZE)
-                    if kv_block_size is None:
-                        kv_block_size = component
-                    elif q_block_size is None:
-                        q_block_size = component
+            # The mask_mod function is the last element in the BlockMask tuple
+            mask_mod_arg = block_mask_arg[-1]
+            if mask_mod_arg is not None and isinstance(mask_mod_arg, torch_fx.Node):
+                assert (
+                    mask_mod_arg.op == "get_attr"
+                ), f"Expected get_attr for mask_mod, got {mask_mod_arg.op}"
+                mask_mod_module = getattr(root_module, mask_mod_arg.target, None)
+                if mask_mod_module is not None:
+                    mask_mod_func_name = self.fx_importer._graph_module_to_func_name[
+                        id(mask_mod_module)
+                    ]
+                    mask_mod_ref = FlatSymbolRefAttr.get(mask_mod_func_name)
 
         # Import scale (float or None)
         if scale_arg is None:
@@ -2018,17 +2006,6 @@ def _import_hop_flex_attention(
         else:
             scale = self._import_argument(loc, scale_arg, None)
 
-        # Get enable_gqa from kernel_options if present
-        enable_gqa = False
-        if isinstance(kernel_options, dict) and "enable_gqa" in kernel_options:
-            enable_gqa = kernel_options["enable_gqa"]
-        with loc:
-            enable_gqa_value = _make_constant_op(
-                "torch.constant.bool",
-                self._cc.integer_attr(1 if enable_gqa else 0, 1),
-                self._cc.torch_bool_type,
-            ).result
-
         # Determine result types from node metadata
         node_val = node.meta.get("val")
         if isinstance(node_val, (list, tuple)) and len(node_val) >= 2:
@@ -2039,6 +2016,13 @@ def _import_hop_flex_attention(
             # Single output
             result_types = [self._cc.node_val_to_type(node)]
 
+        with loc:
+            enable_gqa_value = _make_constant_op(
+                "torch.constant.bool",
+                self._cc.integer_attr(1 if enable_gqa else 0, 1),
+                self._cc.torch_bool_type,
+            ).result
+
         with loc:
             return_lse = _make_constant_op(
                 "torch.constant.bool",
@@ -2059,58 +2043,27 @@ def _import_hop_flex_attention(
                 self._cc.torch_bool_type,
             ).result
 
-        # Build operands for aten.flex_attention
-        # Note: score_mod and block_mask function references go as ATTRIBUTES, not operands
-
-        # Handle block_mask: wrap tensors in a list construct if present
-        if block_mask_tensors:
-            # Wrap block_mask tensors in torch.prim.ListConstruct
-            block_mask_list = Operation.create(
-                "torch.prim.ListConstruct",
-                results=[IrType.parse("!torch.list<vtensor>", context=self._c)],
-                operands=block_mask_tensors,
-                loc=loc,
-            ).result
-        else:
-            # No block mask, use None
-            block_mask_list = Operation.create(
-                "torch.constant.none",
-                results=[self._cc.torch_none_type],
-                loc=loc,
-            ).result
+        # Build operands for aten.flex_attention.
+        # Op expects exactly 6 operands: query, key, value, scale, enable_gqa, return_lse.
+        # Note: score_mod_fn and mask_mod_fn go as ATTRIBUTES, not operands.
+        # Note: block_mask tensors are handled by mask_mod_fn, not passed as operands.
 
         flat_operands = [
             query,
             key,
             value,
-            # score_mod placeholder (None)
-            Operation.create(
-                "torch.constant.none",
-                results=[self._cc.torch_none_type],
-                loc=loc,
-            ).result,
-            # block_mask as single list operand
-            block_mask_list,
             scale,
             enable_gqa_value,
-            # Kernel options as None
-            Operation.create(
-                "torch.constant.none",
-                results=[self._cc.torch_none_type],
-                loc=loc,
-            ).result,
-            # return_lse
             return_lse,
         ]
 
         # Build attributes with function references
-        attributes = {
-            "score_mod_fn": score_mod_ref,
-            "mask_mod_fn": mask_mod_ref,
-            "kv_block_size": self._cc.integer_attr(kv_block_size, 64),
-            "q_block_size": self._cc.integer_attr(q_block_size, 64),
-        }
-        attributes = {k: v for k, v in attributes.items() if v is not None}
+        # Only include attributes if they're not None (OptionalAttr in TableGen)
+        attributes = {}
+        if score_mod_ref is not None:
+            attributes["score_mod_fn"] = score_mod_ref
+        if mask_mod_ref is not None:
+            attributes["mask_mod_fn"] = mask_mod_ref
 
         operation = Operation.create(
             "torch.aten.flex_attention",
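
For context, a minimal sketch of the user-level PyTorch code whose traced HOP this importer handles. The flex_attention and create_block_mask APIs are real (torch.nn.attention.flex_attention); the score_mod/mask_mod bodies, tensor shapes, and names (rel_bias, causal) are illustrative assumptions, not taken from this commit. The importer above turns score_mod and the BlockMask's trailing mask_mod callable into private functions and references them through the score_mod_fn / mask_mod_fn attributes.

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# Illustrative score_mod: add a relative-position bias to each attention score.
# Signature is (score, batch, head, q_idx, kv_idx) -> score.
def rel_bias(score, b, h, q_idx, kv_idx):
    return score + (q_idx - kv_idx)

# Illustrative mask_mod: causal masking. Signature is (batch, head, q_idx, kv_idx) -> bool.
def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

q = torch.randn(2, 4, 128, 64)
k = torch.randn(2, 4, 128, 64)
v = torch.randn(2, 4, 128, 64)

# BlockMask bundles precomputed block tensors (kv_num_blocks, kv_indices, ...) with
# the mask_mod callable; the importer only extracts the mask_mod reference.
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=128, KV_LEN=128, device="cpu")

out = flex_attention(q, k, v, score_mod=rel_bias, block_mask=block_mask, scale=None)

When such a call is exported, the HOP node's args carry q/k/v, the two callables as get_attr submodules, the BlockMask tuple, scale, enable_gqa, and kernel_options, which is the layout the docstring above documents.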