Commit 33365df

Update on "[ET-VK] Allow int4 linear to execute without 8bit buffer support"
## Context

Some Vulkan devices do not support 8-bit buffers. The int4 linear compute shader currently requires them, because its weight prepacking shader reads the weight tensor as `uint8_t`. This diff bypasses that restriction by introducing a variant of the prepacking shader that does not need 8-bit buffers.

## Changes

Introduce a variant of the int4 weight prepacking shader that interprets the tensor data as an array of `uint` instead of `uint8_t`. Each `uint` holds 4 `uint8_t` values.

Differential Revision: [D72750897](https://our.internmc.facebook.com/intern/diff/D72750897/)

[ghstack-poisoned]
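As an illustration of reading `uint8_t` data through a `uint` view, the sketch below mirrors the `(packed4 >> (idx * 8)) & 0xFF` expression visible in the diff context. It is a CPU-side Python model, not the shader itself: the helper names `extract_byte` and `read_u8_via_u32` are hypothetical, and it assumes the bytes are laid out little-endian within each 32-bit word.

```python
import struct


def extract_byte(packed4: int, idx: int) -> int:
    """Return byte `idx` (0 = least significant) of a 32-bit word."""
    return (packed4 >> (idx * 8)) & 0xFF


def read_u8_via_u32(words, byte_index: int) -> int:
    """Read one byte from a buffer that is only addressable as 32-bit words."""
    return extract_byte(words[byte_index // 4], byte_index % 4)


# Eight raw bytes, reinterpreted as two little-endian 32-bit words
# (assumption: little-endian byte order within each word).
raw = bytes([0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xF0])
words = struct.unpack("<2I", raw)

# Reading every byte back through the 32-bit view recovers the original data.
recovered = bytes(read_u8_via_u32(words, i) for i in range(len(raw)))
assert recovered == raw
```

Under these assumptions, byte `i` of the original buffer lives in word `i // 4` at bit offset `8 * (i % 4)`, so no 8-bit buffer access is needed.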
1 parent 20d5e66 commit 33365df

1 file changed: +44 -0 lines changed


backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.glsl

Lines changed: 44 additions & 0 deletions
@@ -56,6 +56,50 @@ $if NO_INT8_BUFFERS:
   return (packed4 >> (idx * 8)) & 0xFF;
 }
 
+/*
+ * This shader packs the weight tensor into a texture.
+ *
+ * The original tensor has a (W, H) shape of (K / 2, N) and each scalar element
+ * is a uint8_t, which contains 2 packed 4 bit uint values.
+ *
+ * The transform performed by this shader is to first transpose the tensor, so
+ * the shape of the packed tensor becomes (N / 2, K). Then, the 4 bit integers
+ * are re-packed in groups of 8. For each 4 uint8_t values, the "left" 4-bits
+ * of each value contain the 0, 1, 2, 3 4-bit values, and the "right" 4-bits of
+ * each value contain the 4, 5, 6, 7 4-bit values.
+ *
+ * As a concrete example, consider the following weight tensor. The | demarks
+ * the packing boundary, so 1| 2 represents a single uint8_t value with 1 in the
+ * leftmost 4 bits and 2 in the rightmost 4 bits.
+ *
+ *  1| 2,  3| 4,  5| 6,  7| 8,
+ *  9|10, 11|12, 13|14, 15|16,
+ * 17|18, 19|20, 21|22, 23|24,
+ * 25|26, 27|28, 29|30, 31|32,
+ * 33|34, 35|36, 37|38, 39|40,
+ * 41|42, 43|44, 45|46, 47|48,
+ * 49|50, 51|52, 53|54, 55|56,
+ * 57|58, 59|60, 61|62, 63|64,
+ *
+ * After packing, the packed tensor would contain
+ *
+ *  1|33,  9|41, 17|49, 25|57,
+ *  2|34, 10|42, 18|50, 26|58,
+ *  3|35, 11|43, 19|51, 27|59,
+ *  4|36, 12|44, 20|52, 28|60,
+ *  5|37, 13|45, 21|53, 29|61,
+ *  6|38, 14|46, 22|54, 30|62,
+ *  7|39, 15|47, 23|55, 31|63,
+ *  8|40, 16|48, 24|56, 32|64,
+ *
+ * The purpose of interleaving is to make it easier to extract the unpacked
+ * values in order using the u8vec4 vectorized type. With the packing in place,
+ * the 4-bit values can be extracted via
+ *
+ * u8vec4 packed;
+ * u8vec4 vals_0123 = (packed & 0xF0) >> 4;
+ * u8vec4 vals_4567 = (packed & 0x0F);
+ */
 void main() {
   // Each thread writes 2 output texels along the height axis
   ivec2 packed_pos = ivec2(
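
The transpose-and-interleave layout described in the shader comment can be checked with a small CPU-side reference. The Python sketch below is an assumption-laden model, not the shader's implementation: the function names are hypothetical, the input is taken as an already unpacked `w[n][k]` matrix, `N` is assumed to be a multiple of 8, and texel layout is ignored. It reproduces the 8x8 example from the comment, then round-trips real 4-bit values through the two masks (the low nibble is taken with a bitwise AND against `0x0F`).

```python
def pack_interleaved(w):
    """Map w[n][k] to rows indexed by k; each row holds (left, right) nibble pairs.

    Within each group of 8 values along n, byte j carries value j in its left
    4 bits and value j + 4 in its right 4 bits.
    """
    n_rows, k_cols = len(w), len(w[0])
    assert n_rows % 8 == 0, "sketch assumes N is a multiple of 8"
    packed = []
    for k in range(k_cols):
        row = []
        for n0 in range(0, n_rows, 8):
            for j in range(4):
                row.append((w[n0 + j][k], w[n0 + 4 + j][k]))
        packed.append(row)
    return packed


def to_byte(left: int, right: int) -> int:
    """Combine two 4-bit values into one byte, left value in the high nibble."""
    return ((left & 0xF) << 4) | (right & 0xF)


def unpack4(packed_bytes):
    """Recover 8 values in order from 4 interleaved bytes."""
    vals_0123 = [(b & 0xF0) >> 4 for b in packed_bytes]
    vals_4567 = [b & 0x0F for b in packed_bytes]
    return vals_0123 + vals_4567


# Labels 1..64 as in the comment's example (labels, not real 4-bit values).
labels = [[n * 8 + k + 1 for k in range(8)] for n in range(8)]
packed_labels = pack_interleaved(labels)
assert packed_labels[0] == [(1, 33), (9, 41), (17, 49), (25, 57)]
assert packed_labels[7] == [(8, 40), (16, 48), (24, 56), (32, 64)]

# Round trip with genuine 4-bit values: output row k=0 yields column 0 in order.
nibbles = [[(3 * n + k) % 16 for k in range(8)] for n in range(8)]
row0 = [to_byte(l, r) for l, r in pack_interleaved(nibbles)[0]]
assert unpack4(row0) == [nibbles[n][0] for n in range(8)]
```

The round trip shows why the interleaving helps: two mask operations on four bytes return eight consecutive values already in order, with no per-element shuffling.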
