# kernel.py
import torch
import numpy as np


def unpack_int4_from_int32(packed_tensor, original_shape):
    """
    Unpack INT4 values from INT32 storage.

    Each INT32 contains 8 INT4 values, lowest nibble first:

        INT32: [31..28][27..24][23..20][19..16][15..12][11..8][7..4][3..0]
                 val7    val6    val5    val4    val3    val2   val1  val0

    Args:
        packed_tensor: torch.Tensor with dtype int32 containing packed INT4 values
        original_shape: tuple, the original shape of the weight before packing

    Returns:
        torch.Tensor with unpacked INT4 values as float32
    """
    # Convert to numpy for reliable bit operations
    packed_np = packed_tensor.detach().cpu().numpy().astype(np.uint32)

    # Extract 8 INT4 values from each UINT32
    unpacked_list = []
    for i in range(8):
        shift = i * 4
        mask = 0xF  # 4-bit mask (0b1111)
        int4_val = (packed_np >> shift) & mask
        # Convert from unsigned [0, 15] to signed [-8, 7] (offset-binary encoding)
        int4_signed = int4_val.astype(np.int8) - 8
        unpacked_list.append(int4_signed)

    # Stack along a new trailing axis ([..., 8]) and flatten, so the 8 values
    # unpacked from each INT32 stay contiguous in C order
    unpacked_np = np.stack(unpacked_list, axis=-1)
    unpacked_1d = unpacked_np.flatten()

    # Calculate expected total elements
    total_elements = 1
    for dim in original_shape:
        total_elements *= dim

    # Handle potential padding
    if len(unpacked_1d) > total_elements:
        unpacked_1d = unpacked_1d[:total_elements]
    elif len(unpacked_1d) < total_elements:
        # Pad if needed (shouldn't happen with correct packing)
        padding = total_elements - len(unpacked_1d)
        unpacked_1d = np.concatenate([unpacked_1d, np.zeros(padding, dtype=np.int8)])

    # Reshape to original shape
    result = unpacked_1d.reshape(original_shape)
    return torch.from_numpy(result).float()
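

# The inverse packing routine is not part of this file, but a minimal sketch
# makes the nibble layout above concrete. NOTE: `_pack_int4_to_int32` is a
# hypothetical helper, not from the original kernel; it assumes the same
# offset-binary encoding (signed value + 8 stored as an unsigned nibble,
# lowest nibble first) that unpack_int4_from_int32() reverses.
def _pack_int4_to_int32(values):
    """Pack a 1-D tensor of signed INT4 values in [-8, 7] into INT32 words."""
    vals = values.detach().cpu().numpy().astype(np.int64)
    assert vals.ndim == 1 and len(vals) % 8 == 0, "pad to a multiple of 8 first"
    # Shift to unsigned [0, 15], then OR eight nibbles into each 32-bit word
    nibbles = (vals + 8).astype(np.uint32).reshape(-1, 8)
    packed = np.zeros(nibbles.shape[0], dtype=np.uint32)
    for i in range(8):
        packed |= nibbles[:, i] << np.uint32(i * 4)
    # Reinterpret the bits as signed INT32 for storage
    return torch.from_numpy(packed.view(np.int32))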


def apply_group_scaling(unpacked, scale_tensor, group_size=32):
    """
    Apply group-wise scaling to unpacked INT4 values.

    Args:
        unpacked: torch.Tensor, shape [out_features, in_features]
        scale_tensor: torch.Tensor, shape [out_features, in_features // group_size]
        group_size: int, number of elements per group (default 32 for Kimi)

    Returns:
        torch.Tensor with scaled values
    """
    if scale_tensor.numel() == 1:
        # Single scale value for the whole tensor
        return unpacked * scale_tensor.item()

    out_features, in_features = unpacked.shape
    scale_out, scale_in = scale_tensor.shape

    if scale_out == out_features and scale_in * group_size == in_features:
        # Standard group-wise scaling:
        # reshape [out_features, in_features] -> [out_features, scale_in, group_size]
        weight_grouped = unpacked.view(out_features, scale_in, group_size)
        # Expand scale: [out_features, scale_in] -> [out_features, scale_in, 1]
        scale_expanded = scale_tensor.view(out_features, scale_in, 1)
        # Apply scaling
        scaled_grouped = weight_grouped * scale_expanded.float()
        # Reshape back: [out_features, scale_in, group_size] -> [out_features, in_features]
        return scaled_grouped.view(out_features, in_features)
    elif scale_out == out_features:
        # Try to handle irregular group sizes
        actual_group_size = in_features // scale_in
        if actual_group_size > 0 and in_features % scale_in == 0:
            weight_grouped = unpacked.view(out_features, scale_in, actual_group_size)
            scale_expanded = scale_tensor.view(out_features, scale_in, 1)
            scaled_grouped = weight_grouped * scale_expanded.float()
            return scaled_grouped.view(out_features, in_features)
        else:
            # Fallback: repeat scales to match dimensions
            scale_repeated = scale_tensor.repeat_interleave(
                (in_features + scale_in - 1) // scale_in, dim=1
            )[:, :in_features]
            return unpacked * scale_repeated.float()
    else:
        # Last resort: try direct broadcasting
        try:
            return unpacked * scale_tensor.float()
        except RuntimeError:
            # Use the mean scale as a final fallback
            return unpacked * scale_tensor.mean().item()
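

# Minimal shape sanity check for the standard group-wise path (an illustrative
# sketch, not part of the original kernel): with out_features=2, in_features=64
# and group_size=32, each row has two scale groups, so scales are shaped [2, 2].
def _check_group_scaling():
    w = torch.randn(2, 64)
    s = torch.full((2, 2), 0.5)
    out = apply_group_scaling(w, s, group_size=32)
    assert out.shape == (2, 64)
    assert torch.allclose(out, w * 0.5)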


def weight_dequant_int4(packed_tensor, scale_tensor, shape_tensor):
    """
    Dequantize INT4 weights to BF16.

    This is the INT4 equivalent of weight_dequant() in the FP8 version.

    Args:
        packed_tensor: torch.Tensor, INT32 tensor with packed INT4 values
        scale_tensor: torch.Tensor, BF16 tensor with group-wise scales
        shape_tensor: torch.Tensor, INT32 tensor with original shape [H, W]

    Returns:
        torch.Tensor in BF16 format with original shape
    """
    # Get original shape
    original_shape = tuple(shape_tensor.tolist())

    # Step 1: Unpack INT4 values from INT32 storage
    unpacked = unpack_int4_from_int32(packed_tensor, original_shape)

    # Step 2: Apply group-wise scaling
    scaled = apply_group_scaling(unpacked, scale_tensor, group_size=32)

    # Step 3: Convert to BF16
    return scaled.to(torch.bfloat16)
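

# End-to-end round trip exercising all three steps. This demo leans on the
# hypothetical `_pack_int4_to_int32` sketch above (real checkpoints ship
# already-packed weights); with unit scales, dequantization should reproduce
# the original signed INT4 values exactly.
if __name__ == "__main__":
    H, W = 4, 64  # W must be a multiple of 8 (packing) and 32 (group size)
    ref = torch.randint(-8, 8, (H, W))  # signed INT4 range [-8, 7]
    packed = _pack_int4_to_int32(ref.flatten()).view(H, W // 8)
    scales = torch.ones(H, W // 32, dtype=torch.bfloat16)
    shape = torch.tensor([H, W], dtype=torch.int32)
    deq = weight_dequant_int4(packed, scales, shape)
    assert deq.dtype == torch.bfloat16 and deq.shape == (H, W)
    assert torch.equal(deq.float(), ref.float())
    _check_group_scaling()
    print("INT4 round-trip OK")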