@@ -77,16 +77,14 @@ def _flash_mha_fwd_hlo_lowering(ctx, q, k, v, softmax_scale=None, is_causal=Fals
     v_type = ir.RankedTensorType(v.type)
     v_shape = v_type.shape
 
-    assert q_type.element_type == k_type.element_type
-    assert q_type.element_type == v_type.element_type
+    assert q_type.element_type == k_type.element_type, "Q and K must have the same dtype"
+    assert q_type.element_type == v_type.element_type, "Q and V must have the same dtype"
     element_type = q_type.element_type
-    assert type(element_type) in [ir.F16Type, ir.BF16Type]
+    assert type(element_type) in [ir.F16Type, ir.BF16Type], "Only fp16 and bf16 dtypes are supported"
     [n, l, h, d] = q_shape
     [nk, lk, hk, dk] = k_shape
-
-
-    assert k_shape == v_shape
-    assert [n, d] == [nk, dk]
+    assert k_shape == v_shape, "K and V must have the same shape"
+    assert [n, d] == [nk, dk], "Q and K must have the same batch size and head size"
 
     opaque = flash_api.make_flash_mha_fwd_args(
         0.0, # p_dropout
@@ -100,47 +98,39 @@ def _flash_mha_fwd_hlo_lowering(ctx, q, k, v, softmax_scale=None, is_causal=Fals
         flash_api.BF16 if type(element_type) == ir.BF16Type else flash_api.FP16,
         0)
 
-    lse_type = ir.RankedTensorType.get([n, h, l], ir.F32Type.get(ctx.module_context.context))
-
-    if d % 8 != 0:
-        # We need padding. It's better to let xla's allocator handle it here than directly call cudaMalloc.
-        dpad = 8 - d % 8
-
-        z = np.array(0.0, dtype=ir_type_to_dtype(element_type))
-        z = mlir.ir_constant(z)
-        q_padded = mlir.hlo.PadOp(q, z, [0,0,0,0], [0,0,0,dpad], [0,0,0,0]).result
-        k_padded = mlir.hlo.PadOp(k, z, [0,0,0,0], [0,0,0,dpad], [0,0,0,0]).result
-        v_padded = mlir.hlo.PadOp(v, z, [0,0,0,0], [0,0,0,dpad], [0,0,0,0]).result
-
+    def fwd(q, k, v):
+        dpad = (8 - d % 8) % 8
+        if dpad > 0:
+            # We need padding. It's better to let xla's allocator handle it here than directly call cudaMalloc.
+            q = jnp.pad(q, ((0,0), (0,0), (0,0), (0,dpad)), 'constant')
+            k = jnp.pad(k, ((0,0), (0,0), (0,0), (0,dpad)), 'constant')
+            v = jnp.pad(v, ((0,0), (0,0), (0,0), (0,dpad)), 'constant')
+
         q_shape = [n, l, h, d + dpad]
         k_shape = [n, lk, hk, d + dpad]
         v_shape = [n, lk, hk, d + dpad]
         o_shape = [n, l, h, d + dpad]
+        lse_shape = [n, h, l]
 
+
+        lse_type = ir.RankedTensorType.get([n, h, l], mlir.dtype_to_ir_type(jnp.float32.dtype))
         out_types = [ir.RankedTensorType.get(o_shape, element_type), lse_type]
+        operand_layouts = default_layouts(q_shape, k_shape, v_shape)
+        result_layouts = default_layouts(o_shape, lse_shape)
 
-        (o, lse) = mlir.custom_call(
-            b"flash_mha_fwd",
+        o, lse = custom_call(
+            q, k, v,
+            call_target_name=b"flash_mha_fwd",
             result_types=out_types,
-            operands=[q_padded, k_padded, v_padded],
             backend_config=opaque,
-            operand_layouts=default_layouts(q_shape, k_shape, v_shape),
-            result_layouts=default_layouts(*[o.shape for o in out_types]),
-        ).results
+            operand_layouts=operand_layouts,
+            result_layouts=result_layouts,
+        )
 
-        o = mlir.hlo.SliceOp(o, [0,0,0,0], (n, l, h, d), [1,1,1,1]).result
-        return (o, lse)
-    else:
-        out_types = [ir.RankedTensorType.get([n, l, h, d], element_type), lse_type]
-        out = mlir.custom_call(
-            b"flash_mha_fwd",
-            result_types=out_types,
-            operands=[q, k, v],
-            backend_config=opaque,
-            operand_layouts=default_layouts(q_shape, k_shape, v_shape),
-            result_layouts=default_layouts(*[o.shape for o in out_types]),
-        ).results
-        return out
+        if dpad > 0:
+            o = o[:, :, :, :d]
+        return o, lse
+    return mlir.lower_fun(fwd, multiple_results=True)(ctx, q, k, v)
 
 mlir.register_lowering(
     _flash_mha_fwd_hlo_p,
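Note: the change above stops building mlir.hlo.PadOp / SliceOp by hand (and duplicating the custom call for the padded and unpadded cases) and instead writes the whole forward lowering as an ordinary JAX function handed to mlir.lower_fun, so padding and un-padding become jnp.pad and array slicing. The snippet below is a minimal, self-contained sketch of that lower_fun pattern, not code from this repository: the toy primitive, the jnp.tanh stand-in for the flash_mha_fwd custom call, and the import location of Primitive (newer JAX exposes it under jax.extend.core) are illustrative assumptions.

import jax
import jax.numpy as jnp
from jax import core                      # newer JAX versions also expose this via jax.extend.core
from jax.interpreters import mlir

# Hypothetical toy primitive, used only to illustrate the lower_fun pattern.
toy_p = core.Primitive("toy_pad_demo")
toy_p.def_abstract_eval(lambda x: core.ShapedArray(x.shape, x.dtype))

def _toy_lowering_fun(x):
    # Written as an ordinary JAX function: jnp.pad / slicing instead of hand-built
    # mlir.hlo.PadOp / SliceOp, so XLA's allocator owns the padded buffers.
    d = x.shape[-1]
    dpad = (8 - d % 8) % 8                # pad the last axis up to a multiple of 8
    if dpad > 0:
        x = jnp.pad(x, ((0, 0),) * (x.ndim - 1) + ((0, dpad),))
    y = jnp.tanh(x)                       # stand-in for the real flash_mha_fwd custom call
    if dpad > 0:
        y = y[..., :d]                    # slice the padding back off
    return y

# lower_fun traces the Python function and emits the corresponding MLIR for the primitive.
mlir.register_lowering(toy_p, mlir.lower_fun(_toy_lowering_fun, multiple_results=False))

def toy(x):
    return toy_p.bind(x)

print(jax.jit(toy)(jnp.ones((2, 5), jnp.float32)).shape)   # (2, 5)

Collapsing the padded and unpadded paths into a single traced function is what lets the diff drop the duplicated custom_call branch for the d % 8 == 0 case.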