Skip to content

Commit 6c0c9d5

Browse files
committed
[ET-VK][ez] Add support for buffer backed qparams in int4 linear + add checks for physical limits when allocating
## Context Currently, the groupwise quantized int4 linear op implementation forces the scales and zeros tensor to be a `Texture3D`. However, for some models (e.g. transformer models that have a logit linear layer), the image extents required may exceed the maximum image extents available on the device. ## Changes * Add support for the scales and zeros tensor being a `Buffer` instead of a `Texture3D` * When allocating buffers or images for tensors, add checks that the requested resource fits within the physical device limits Differential Revision: [D72662176](https://our.internmc.facebook.com/intern/diff/D72662176/) ghstack-source-id: 276858281 Pull Request resolved: #9974
1 parent a67b6b8 commit 6c0c9d5

File tree

5 files changed

+42
-13
lines changed

5 files changed

+42
-13
lines changed

backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,11 @@ vkapi::VulkanImage allocate_image(
260260
return vkapi::VulkanImage();
261261
}
262262

263+
utils::uvec3 max_extents = adapter_ptr->max_texture_extents();
264+
VK_CHECK_COND(
265+
image_extents[0] <= max_extents[0] &&
266+
image_extents[1] <= max_extents[1] && image_extents[2] <= max_extents[2]);
267+
263268
VkSampler sampler = adapter_ptr->sampler_cache().retrieve(sampler_props);
264269

265270
return adapter_ptr->vma().create_image(
@@ -291,6 +296,8 @@ vkapi::VulkanBuffer allocate_buffer(
291296
return vkapi::VulkanBuffer();
292297
}
293298

299+
VK_CHECK_COND(numel <= context_ptr->adapter_ptr()->max_buffer_numel());
300+
294301
return adapter_ptr->vma().create_storage_buffer(
295302
element_size(dtype) * numel, allocate_memory);
296303
}

backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ layout(std430) buffer;
2121
${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
2222
${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)}
2323
${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
24-
${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "texture3D")}
24+
${layout_declare_tensor(B, "r", "t_qparams", DTYPE, PARAMS_STORAGE, is_scalar_array=False)}
2525

2626
layout(push_constant) uniform restrict Block {
2727
ivec4 out_sizes;
@@ -79,13 +79,23 @@ void main() {
7979

8080
$if WEIGHT_STORAGE == "buffer":
8181
const int qmat2_stride = qmat2_sizes.x >> 2;
82+
$if PARAMS_STORAGE == "buffer":
83+
const int qparams_y_stride = out_sizes.x >> 2;
84+
const int qparams_z_stride = qparams_y_stride * 2;
8285

8386
for (int block_idx = 0; block_idx < num_blocks; ++block_idx) {
84-
scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0);
85-
zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0);
86-
87-
scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0);
88-
zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0);
87+
$if PARAMS_STORAGE == "buffer":
88+
scales[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx];
89+
zeros[0] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + qparams_y_stride];
90+
91+
scales[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1];
92+
zeros[1] = t_qparams[block_idx * qparams_z_stride + out_col_texel_idx + 1 + qparams_y_stride];
93+
$else:
94+
scales[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 0, block_idx), 0);
95+
zeros[0] = texelFetch(t_qparams, ivec3(out_col_texel_idx, 1, block_idx), 0);
96+
97+
scales[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 0, block_idx), 0);
98+
zeros[1] = texelFetch(t_qparams, ivec3(out_col_texel_idx + 1, 1, block_idx), 0);
8999

90100
for (int g_idx = 0; g_idx < group_size; g_idx += 4) {
91101
const int k = block_idx * group_size + g_idx;

backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.yaml

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,13 @@ q_4w_linear:
1010
OUT_STORAGE: texture3d
1111
IN_STORAGE: texture3d
1212
WEIGHT_STORAGE: texture3d
13+
PARAMS_STORAGE: texture3d
1314
shader_variants:
14-
- NAME: q_4w_linear_texture3d_texture3d_texture3d_float
15-
- NAME: q_4w_linear_texture3d_buffer_texture3d_float
16-
IN_STORAGE: buffer
17-
- NAME: q_4w_linear_buffer_buffer_texture3d_float
15+
- NAME: q_4w_linear_texture3d_texture3d_texture3d_texture3d_float
16+
- NAME: q_4w_linear_buffer_buffer_texture3d_texture3d_float
1817
OUT_STORAGE: buffer
1918
IN_STORAGE: buffer
20-
- NAME: q_4w_linear_buffer_buffer_buffer_float
19+
- NAME: q_4w_linear_buffer_buffer_texture3d_buffer_float
2120
OUT_STORAGE: buffer
2221
IN_STORAGE: buffer
23-
WEIGHT_STORAGE: buffer
22+
PARAMS_STORAGE: buffer

backends/vulkan/runtime/graph/ops/impl/QuantizedLinearGroupwiseInt4.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,13 +132,22 @@ void add_q_4w_linear_node(
132132
ValueRef mat2 =
133133
prepack_int4_linear_weight_transposed_interleaved(graph, mat2_data);
134134

135+
utils::StorageType qparams_storage_type = utils::kTexture3D;
136+
utils::uvec3 max_extents =
137+
graph.context()->adapter_ptr()->max_texture_extents();
138+
if (graph.size_at<uint32_t>(-2, scales_and_zeros_data) > max_extents[0] * 4 ||
139+
graph.size_at<uint32_t>(-3, scales_and_zeros_data) > max_extents[2]) {
140+
qparams_storage_type = utils::kBuffer;
141+
}
142+
135143
ValueRef scales_and_zeros = prepack_standard_hw_transposed(
136-
graph, scales_and_zeros_data, utils::kTexture3D, utils::kWidthPacked);
144+
graph, scales_and_zeros_data, qparams_storage_type, utils::kWidthPacked);
137145

138146
std::string kernel_name = "q_4w_linear";
139147
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
140148
add_storage_type_suffix(kernel_name, graph.storage_type_of(mat1));
141149
add_storage_type_suffix(kernel_name, graph.storage_type_of(mat2));
150+
add_storage_type_suffix(kernel_name, qparams_storage_type);
142151
add_dtype_suffix(kernel_name, graph.dtype_of(out));
143152

144153
const uint32_t group_size_val = graph.extract_scalar<uint32_t>(group_size);

backends/vulkan/runtime/vk_api/Adapter.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,10 @@ class Adapter final {
218218
physical_device_.properties.limits.maxImageDimension3D};
219219
}
220220

221+
// Returns the physical device's `maxStorageBufferRange` limit.
// NOTE(review): the Vulkan spec defines `maxStorageBufferRange` as the maximum
// descriptor range in bytes, whereas this accessor's name (and the
// `numel <= max_buffer_numel()` check in allocate_buffer) treat the value as an
// element count. For dtypes wider than one byte the check would then be more
// permissive than the device limit - TODO confirm the intended unit.
inline uint32_t max_buffer_numel() const {
222+
return physical_device_.properties.limits.maxStorageBufferRange;
223+
}
224+
221225
// Command Buffer Submission
222226

223227
void

0 commit comments

Comments
 (0)