Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,12 @@ def register_mm_op(features: OpFeatures):
return features


@update_features(exir_ops.edge.aten._weight_int8pack_mm.default)
@update_features(
[
exir_ops.edge.aten._weight_int8pack_mm.default,
exir_ops.edge.et_vk.linear_qcs4w.default,
]
)
def register_int8_mm_op(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
uses_axis_map=False,
Expand Down
18 changes: 14 additions & 4 deletions backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,32 @@
/*
* Fast division by 4 using bit shifting
*/
#define div4(x) (x >> 2)
#define div4(x) ((x) >> 2)

/*
* Fast multiplication by 4 using bit shifting
*/
#define mul4(x) ((x) << 2)

/*
* Divides input and rounds up to 4
*/
#define divup4(x) ((x + 3) >> 2)
#define divup4(x) (((x) + 3) >> 2)

/*
* Divides input by denominator and rounds up
*/
#define divup(x, d) (((x) + (d) - 1) / (d))

/*
* Aligns input to the next multiple of 4
*/
#define alignup4(x) ((x + 3) & -4)
#define alignup4(x) (((x) + 3) & -4)

/*
* Fast modulo by 4 using bit masking
*/
#define mod4(x) (x & 3)
#define mod4(x) ((x) & 3)

/*
* Find the packed dimension of a tensor given its strides. The packed dimension
Expand Down
145 changes: 102 additions & 43 deletions backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define VEC4_T ${buffer_gvec_type(DTYPE, 4)}

#define TILE_ROWS ${TILE_ROWS}
#define TILE_TXCOLS ${TILE_TXCOLS}

#define NGROUPS 8
#define NWORKERS 8
Expand All @@ -29,7 +30,10 @@ layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)}
$if QUANT_NBITS == 4:
${layout_declare_tensor(B, "r", "t_weight", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
$else:
${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array=False)}

layout(push_constant) uniform restrict Block {
Expand All @@ -42,12 +46,23 @@ layout(push_constant) uniform restrict Block {

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

shared VEC4_T partial_c[NGROUPS][NWORKERS][TILE_ROWS];
shared VEC4_T partial_sums[NGROUPS][NWORKERS][TILE_ROWS][TILE_TXCOLS];

void main() {
const uint out_width_ntexels = divup4(out_sizes.x);
const uint out_col = (gl_GlobalInvocationID.x % out_width_ntexels) << 2;
const uint out_row = (gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS;
// txcol stands for "texel column". One txcol corresponds to 4 scalar columns.
$if TILE_TXCOLS > 1:
const uint global_wg_x = uint(divup(out_sizes.x, 4 * TILE_TXCOLS));
const uint out_txcol = uint(
(gl_GlobalInvocationID.x % global_wg_x) * TILE_TXCOLS);
$else:
const uint global_wg_x = uint(divup4(out_sizes.x));
const uint out_txcol = uint(gl_GlobalInvocationID.x % global_wg_x);

const uint out_row = uint(
(gl_GlobalInvocationID.x / global_wg_x) * TILE_ROWS);

$if QUANT_NBITS == 4:
const uint weight_txcol = uint(out_txcol / 2);

const int gid = int(gl_LocalInvocationID.x); // group id
const int wid = int(gl_LocalInvocationID.z); // worker id
Expand All @@ -56,46 +71,78 @@ void main() {
return;
}

VEC4_T a[TILE_ROWS];
VEC4_T b[4];
VEC4_T local_c[TILE_ROWS];
VEC4_T mat1[TILE_ROWS];
VEC4_T qmat2[4][TILE_TXCOLS];
VEC4_T local_sums[TILE_ROWS][TILE_TXCOLS];

[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
local_c[i] = VEC4_T(0.0);
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
$for c in range(TILE_TXCOLS):
local_sums[r][${c}] = VEC4_T(0.0);
}

$if SCALES_STORAGE == "buffer":
const VEC4_T scales = VEC4_T(t_scales[out_col >> 2]);
$else:
const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2(out_col >> 2, 0), 0));

for (int pos = 4 * wid; pos < in_sizes.x; pos += (4 * NWORKERS)) {
// Preload t_weight
[[unroll]] for (int i = 0; i < 4; i++) {
$if WEIGHT_STORAGE == "buffer":
b[i] = t_weight[((pos + i) * weight_sizes.x + out_col) >> 2];
VEC4_T scales[TILE_TXCOLS];
$for c in range(TILE_TXCOLS):
$if SCALES_STORAGE == "buffer":
scales[${c}] = VEC4_T(t_scales[out_txcol + ${c}]);
$else:
scales[${c}] = VEC4_T(
texelFetch(t_scales, ivec2(out_txcol + ${c}, 0), 0));

for (int pos = (4 * wid), txpos = wid;
pos < in_sizes.x;
pos += (4 * NWORKERS), txpos += NWORKERS) {
$if WEIGHT_STORAGE == "buffer":
uint qmat2_bufi;
uint weight_row_txstride = div4(weight_sizes.x);

// Preload weight tensor
[[unroll]] for (int r = 0; r < 4; r++) {
$if QUANT_NBITS == 4:
$for c in range(0, TILE_TXCOLS, 2):
$if WEIGHT_STORAGE == "buffer":
qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol;
const u8vec4 packed_weight_tex = t_weight[qmat2_bufi + ${c}]
$else:
const uvec4 packed_weight_tex = texelFetch(
t_weight, ivec2(weight_txcol + ${c}, pos + r), 0);

qmat2[r][${c}] = (VEC4_T((packed_weight_tex & 0xF0) >> 4) - 8.0);
qmat2[r][${c + 1}] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0);
$else:
b[i] = VEC4_T(texelFetch(t_weight, ivec2(out_col >> 2, pos + i), 0));
$for c in range(TILE_TXCOLS):
$if WEIGHT_STORAGE == "buffer":
qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol;
qmat2[r][${c}] = t_weight[qmat2_bufi + ${c}];
$else:
qmat2[r][${c}] = VEC4_T(
texelFetch(t_weight, ivec2(out_txcol + ${c}, pos + r), 0));
}
// Preload t_in
for (int i = 0; i < TILE_ROWS; i++) {

$if IN_STORAGE == "buffer":
uint in_row_txstride = div4(in_sizes.x);

// Preload input tensor
[[unroll]] for (int i = 0; i < TILE_ROWS; i++) {
$if IN_STORAGE == "buffer":
a[i] = t_in[((out_row + i) * in_sizes.x + pos) >> 2];
mat1[i] = t_in[(out_row + i) * in_row_txstride + txpos];
$else:
a[i] = VEC4_T(texelFetch(t_in, ivec3(pos >> 2, out_row + i, 0), 0));
mat1[i] = VEC4_T(
texelFetch(t_in, ivec3(txpos, out_row + i, 0), 0));
}

// Accumulate partial output
[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
local_c[i] += a[i].x * b[0] +
a[i].y * b[1] +
a[i].z * b[2] +
a[i].w * b[3];
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
$for c in range(TILE_TXCOLS):
local_sums[r][${c}] += mat1[r].x * qmat2[0][${c}] +
mat1[r].y * qmat2[1][${c}] +
mat1[r].z * qmat2[2][${c}] +
mat1[r].w * qmat2[3][${c}];
}
}

[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
partial_c[gid][wid][i] = local_c[i];
[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
$for c in range(TILE_TXCOLS):
partial_sums[gid][wid][r][${c}] = local_sums[r][${c}];
}

memoryBarrierShared();
Expand All @@ -105,21 +152,33 @@ void main() {
return;
}

VEC4_T c[TILE_ROWS];
VEC4_T sums[TILE_ROWS][TILE_TXCOLS];

for (int r = 0; r < TILE_ROWS; ++r) {
$for c in range(TILE_TXCOLS):
sums[r][${c}] = VEC4_T(0.0);

for (int row = 0; row < TILE_ROWS; ++row) {
c[row] = VEC4_T(0.0);
[[unroll]] for (int worker = 0; worker < NWORKERS; ++worker) {
c[row] += partial_c[gid][worker][row];
$for c in range(TILE_TXCOLS):
sums[r][${c}] += partial_sums[gid][worker][r][${c}];
}
}

[[unroll]] for (int i = 0; i < TILE_ROWS; ++i) {
$if OUT_STORAGE == "buffer":
if (out_row + i < out_sizes.y) {
t_out[((out_row + i) * out_sizes.x + out_col) >> 2] = c[i] * scales;
}
$else:
imageStore(t_out, ivec3(out_col >> 2, out_row + i, 0), c[i] * scales);
$if OUT_STORAGE == "buffer":
uint out_bufi;
uint out_row_txstride = div4(out_sizes.x);

[[unroll]] for (int r = 0; r < TILE_ROWS; ++r) {
$for c in range(TILE_TXCOLS):
$if OUT_STORAGE == "buffer":
if (out_row + r < out_sizes.y) {
out_bufi = (out_row + r) * out_row_txstride + out_txcol;
t_out[out_bufi + ${c}] = sums[r][${c}] * scales[${c}];
}
$else:
imageStore(
t_out,
ivec3(out_txcol + ${c}, out_row + r, 0),
sums[r][${c}] * scales[${c}]);
}
}
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_coop.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ linear_qcsnw_coop:
WEIGHT_STORAGE: texture2d
SCALES_STORAGE: texture2d
TILE_ROWS: 4
TILE_TXCOLS: 1
QUANT_NBITS: 8
generate_variant_forall:
TILE_ROWS:
- VALUE: 1
Expand Down
Loading
Loading