Skip to content

Commit d42c6ee

Browse files
committed
Use kwargs
1 parent fb989d4 commit d42c6ee

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

thunder/executors/nvfuserex_impl.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2568,13 +2568,10 @@ def _scaled_dot_product_flash_attention_forward(
25682568
fd: FusionDefinition,
25692569
lc_to_nv_map: dict,
25702570
) -> Any:
2571-
inputs = [query, key, value, dropout_p, is_causal, scale]
2572-
nv_inputs = []
2573-
for inp in inputs:
2574-
nv_inp = getnv(inp, fd, lc_to_nv_map) if inp is not None else None
2575-
nv_inputs.append(nv_inp)
2576-
2577-
return fd.ops.sdpfa_fwd(*nv_inputs)
2571+
args = [query, key, value]
2572+
for i, arg in enumerate(args):
2573+
args[i] = getnv(arg, fd, lc_to_nv_map)
2574+
return fd.ops.sdpfa_fwd(*args, dropout_p=dropoutp, is_causal=is_causal, scale=scale)
25782575

25792576

25802577
nv_sdpfa_fwd = ex.register_operator(

0 commit comments

Comments
 (0)