Skip to content

Commit bc71d30

Browse files
authored
Call fd.ops.sdpfa_fwd with kwargs (#2799)
1 parent 33d7edf commit bc71d30

File tree

1 file changed

+2
-7
lines changed

1 file changed

+2
-7
lines changed

thunder/executors/nvfuserex_impl.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2568,13 +2568,8 @@ def _scaled_dot_product_flash_attention_forward(
25682568
fd: FusionDefinition,
25692569
lc_to_nv_map: dict,
25702570
) -> Any:
2571-
inputs = [query, key, value, dropout_p, is_causal, scale]
2572-
nv_inputs = []
2573-
for inp in inputs:
2574-
nv_inp = getnv(inp, fd, lc_to_nv_map) if inp is not None else None
2575-
nv_inputs.append(nv_inp)
2576-
2577-
return fd.ops.sdpfa_fwd(*nv_inputs)
2571+
args = [getnv(arg, fd, lc_to_nv_map) for arg in (query, key, value)]
2572+
return fd.ops.sdpfa_fwd(*args, dropout_p=dropoutp, is_causal=is_causal, scale=scale)
25782573

25792574

25802575
nv_sdpfa_fwd = ex.register_operator(

0 commit comments

Comments
 (0)