Skip to content

Commit 89336fb

Browse files
limin2021kaiyux
andauthored
[None][fix] Fix cute dsl nvfp4 gemm autotune issue (#8761)
Signed-off-by: Mindy Li <[email protected]> Co-authored-by: Kaiyu Xie <[email protected]>
1 parent f48968b commit 89336fb

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,15 @@ def __init__(self, alpha: float, output_dtype: torch.dtype):
4949
f"SM version {get_sm_version()} is not supported for CuteDSLNVFP4BlackwellLinear, it only supports SM 100"
5050
)
5151

52+
# rewrite the hash function because the value of self.alpha doesn't affect the tactic.
53+
def __hash__(self):
54+
return hash((self.output_dtype, ))
55+
56+
def __eq__(self, other):
57+
if not isinstance(other, CuteDSLNVFP4BlackwellLinear):
58+
return False
59+
return self.output_dtype == other.output_dtype
60+
5261
def get_valid_tactics(
5362
self,
5463
inputs: List[torch.Tensor],

0 commit comments

Comments
 (0)