
Commit cf9023a

Extract logit_cap from the args position, not the value
Signed-off-by: nvchenghaoz <[email protected]>
1 parent: 7228d98

File tree: 1 file changed (+6, -1)


tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py

Lines changed: 6 additions & 1 deletion
@@ -425,7 +425,12 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
             ad_logger.warning("Provided scale is not a float. Using default scale instead.")
             scale = None
 
-        logit_cap = source_attn_node.kwargs.get("logit_cap", None)
+        # Get logit_cap from args or kwargs - it's typically the 8th argument (index 7)
+        if len(source_attn_node.args) > 7:
+            logit_cap = source_attn_node.args[7]
+        else:
+            logit_cap = source_attn_node.kwargs.get("logit_cap", None)
+
         if not (isinstance(logit_cap, float) or logit_cap is None):
             ad_logger.debug("Provided logit_cap is not a float or None. Disabling soft-capping.")
             logit_cap = None

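For context, the sketch below illustrates the failure mode this commit works around: on a torch.fx Node, a value passed positionally never appears in node.kwargs, so kwargs.get("logit_cap") returns None even though the caller supplied it. The toy_attention function, its parameter list, and the index used here are hypothetical stand-ins, not the real flashinfer custom op.

# Minimal sketch: positional args on an FX node are invisible to kwargs.get().
# toy_attention is a made-up leaf function; only its call signature matters.
import torch
from torch import fx

# Mark the free function as a tracing leaf so it shows up as a call_function node.
torch.fx.wrap("toy_attention")


def toy_attention(q, k, v, mask, dropout, causal, scale, logit_cap=None):
    # Placeholder body; never traced because the function is wrapped.
    return q


class Model(torch.nn.Module):
    def forward(self, q, k, v):
        # logit_cap (30.0) is passed positionally at index 7, not as a keyword.
        return toy_attention(q, k, v, None, 0.0, True, 1.0, 30.0)


graph = fx.symbolic_trace(Model()).graph
for node in graph.nodes:
    if node.op == "call_function" and node.target is toy_attention:
        # kwargs is empty, so the old kwargs-only lookup misses the value...
        print(node.kwargs.get("logit_cap", None))  # -> None
        # ...while indexing into args recovers it, mirroring the fix above.
        print(node.args[7] if len(node.args) > 7 else None)  # -> 30.0

In the patched pass, source_attn_node plays the role of this node, and the len(args) > 7 check corresponds to the attention op's positional slot noted in the diff comment.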