Skip to content

Commit bcc4e69

Browse files
author
Doug Lehr
committed
Add triton gemm calls for unquantized gemms
1 parent 4d63faf commit bcc4e69

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

vllm/model_executor/layers/utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,14 @@ def aiter_GEMM_check(m, n, k):
104104
return False
105105

106106

107+
# (n, k) pairs for which aiter_GEMM_check reports a match.
# Kept as a tuple (not a set) so matching relies only on ``==``, exactly like
# the explicit comparisons this table replaces.
_AITER_TRITON_GEMM_SHAPES = (
    (5120, 2880),
    (2880, 4096),
    (128, 2880),
    (640, 2880),
    (2880, 512),
)


# NOTE(review): the diff hunk header lists ``aiter_GEMM_check`` as the
# enclosing context just above this definition -- confirm this is not a
# second, shadowing definition of the same function.
def aiter_GEMM_check(m, n, k) -> bool:
    """Return True when ``(n, k)`` is one of the hard-coded GEMM shapes.

    Args:
        m: Accepted for interface compatibility; not used in the decision.
        n: Compared against the first element of each known shape pair.
        k: Compared against the second element of each known shape pair.

    Returns:
        True if ``(n, k)`` equals an entry of ``_AITER_TRITON_GEMM_SHAPES``,
        False otherwise.

    NOTE(review): the visible call site invokes this as
    ``aiter_GEMM_check(n, m, k)`` -- confirm that the positional order there
    lines up with this ``(m, n, k)`` signature as intended.
    """
    return (n, k) in _AITER_TRITON_GEMM_SHAPES
107115

108116
def rocm_unquantized_gemm_impl(
109117
x: torch.Tensor,
@@ -118,6 +126,9 @@ def rocm_unquantized_gemm_impl(
118126
x.dtype in [torch.float16, torch.bfloat16] \
119127
and k % 8 == 0 and bias is None)
120128

129+
if VLLM_USE_AITER_TRITON_GEMM and aiter_GEMM_check(n, m, k):
130+
return gemm_a16w16(x, weight, bias)
131+
121132
if use_skinny is not True:
122133
return torch.nn.functional.linear(x, weight, bias)
123134

0 commit comments

Comments
 (0)