Commit ff99639

Add register_fake for finegrained_mixed_dtype_gemm torch_op (NVIDIA#6255)
Signed-off-by: Daniel Afrimi <[email protected]>
1 parent 6007373 commit ff99639

1 file changed: +20 -0 lines changed

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 20 additions & 0 deletions
@@ -851,6 +851,26 @@ def finegrained_mixed_dtype_gemm(
         **kwargs)
 
 
+@finegrained_mixed_dtype_gemm.register_fake
+def _(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    scales: torch.Tensor,
+    group_size: int,
+    has_zero_point: bool,
+    output_dtype: torch.dtype,
+    alpha: Optional[float] = None,
+    bias: Optional[torch.Tensor] = None,
+    zeros: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    # For a typical GEMM: input [M, K] @ weight [K, N] -> output [M, N]
+    # Weight is typically packed, so we need to infer the output dimension
+    M = input.size(0)
+    # Assuming weight is packed and the output dimension can be inferred from weight.size(1)
+    N = weight.size(1) if weight.dim() > 1 else weight.size(0)
+    return input.new_empty((M, N), dtype=output_dtype)
+
+
 @torch.library.custom_op("trtllm::attention", mutates_args=())
 def attention(
     q: torch.Tensor,
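
For context on what this change accomplishes: register_fake attaches a shape-only "fake" kernel to a torch custom op, so the op can be traced under FakeTensorMode (which torch.compile and export use) without running the real GEMM kernel. The sketch below is a minimal, hypothetical illustration of the same pattern, not code from this commit: the op name "mylib::packed_gemm" and its signature are stand-ins for trtllm's finegrained_mixed_dtype_gemm, and it assumes the PyTorch 2.4+ torch.library.custom_op API.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Hypothetical custom op used only to illustrate the register_fake pattern.
@torch.library.custom_op("mylib::packed_gemm", mutates_args=())
def packed_gemm(inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Eager implementation: a plain matmul stands in for the real kernel.
    return inp @ weight

@packed_gemm.register_fake
def _(inp: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Fake kernel: only describe the output's shape and dtype; no computation runs.
    return inp.new_empty((inp.size(0), weight.size(1)))

# Under FakeTensorMode the fake kernel is dispatched instead of the real one,
# so output shapes propagate without touching real data.
with FakeTensorMode():
    a = torch.empty(8, 16)
    b = torch.empty(16, 32)
    out = packed_gemm(a, b)
    assert tuple(out.shape) == (8, 32)

In the same way, the fake kernel added in this commit lets the tracer infer an [M, N] output of output_dtype for finegrained_mixed_dtype_gemm without dispatching the mixed-dtype GEMM kernel.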
