Skip to content

Commit 6dbd4ce

Browse files
authored
fix
1 parent d615f11 commit 6dbd4ce

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -73,15 +73,19 @@ def _post_load_weights(self) -> None:
7373
and (not self.static_activation or self.input_scale is not None)
7474
):
7575
if self.weight_scale.ndim > 1:
76+
# 让 k dim 更连续,大多数split k 算法的算子可能能更快
7677
self.weight_scale = self.weight_scale.cuda(self.device_id_).transpose(0, 1)
7778
self.weight = [
79+
# 让 k dim 更连续,大多数split k 算法的算子可能能更快
7880
self.weight.cuda(self.device_id_).transpose(0, 1),
7981
self.weight_scale,
8082
self.input_scale,
8183
]
8284
else:
8385
self.weight = self.quant_method.quantize(self.weight.to(self.data_type_).cuda(self.device_id_))
8486
return
87+
88+
# 让 k dim 更连续,大多数split k 算法的算子可能能更快
8589
self.weight = self.weight.to(self.data_type_).cuda(self.device_id_).transpose(0, 1)
8690

8791

lightllm/common/quantization/vllm_quant.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -198,7 +198,6 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
198198
dtype=input_tensor.dtype,
199199
)
200200
else:
201-
# qweight = qweight.t().contiguous().t()
202201
input_scale = input_scale.t().contiguous().t()
203202
torch.ops._C.cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias)
204203
return out

0 commit comments

Comments (0)