Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 53 additions & 9 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,8 @@ struct vk_device_struct {
bool subgroup_ballot;
bool subgroup_clustered;
bool multi_add;
bool shader_int64;
bool buffer_device_address;

bool add_rms_fusion;
uint32_t partials_binding_alignment;
Expand Down Expand Up @@ -653,6 +655,7 @@ struct vk_buffer_struct {
vk::MemoryPropertyFlags memory_property_flags;
void * ptr;
size_t size = 0;
vk::DeviceAddress bda_addr {};

vk_device device;

Expand Down Expand Up @@ -985,6 +988,7 @@ struct vk_op_argsort_push_constants {
};

struct vk_op_im2col_push_constants {
uint64_t dst_addr;
uint32_t batch_offset; uint32_t offset_delta;
uint32_t IC;
uint32_t IW; uint32_t IH;
Expand All @@ -998,6 +1002,7 @@ struct vk_op_im2col_push_constants {
};

struct vk_op_im2col_3d_push_constants {
uint64_t dst_addr;
uint32_t nb10;
uint32_t nb11;
uint32_t nb12;
Expand Down Expand Up @@ -2010,10 +2015,17 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
return buf;
}

vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst;
vk::MemoryAllocateFlags mem_flags {};
if (device->buffer_device_address) {
usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress;
mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress;
}

vk::BufferCreateInfo buffer_create_info{
vk::BufferCreateFlags(),
size,
vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst,
usage_flags,
vk::SharingMode::eExclusive,
0,
nullptr,
Expand All @@ -2025,6 +2037,8 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std

vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();

const vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };

for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
const auto & req_flags = *it;

Expand All @@ -2036,7 +2050,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
buf->memory_property_flags = req_flags;

try {
buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index, &mem_flags_info });
break;
} catch (const vk::SystemError& e) {
// loop and retry
Expand Down Expand Up @@ -2064,6 +2078,11 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
buf->device = device;
buf->size = size;

if (device->buffer_device_address) {
const vk::BufferDeviceAddressInfo addressInfo(buf->buffer);
buf->bda_addr = device->device.getBufferAddress(addressInfo);
}

#ifdef GGML_VULKAN_MEMORY_DEBUG
device->memory_logger->log_allocation(buf, size);
#endif
Expand Down Expand Up @@ -3530,14 +3549,20 @@ static void ggml_vk_load_shaders(vk_device& device) {

ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);

ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
if (device->float_controls_rte_fp16) {
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
#define IM2COL(bda) \
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
if (device->float_controls_rte_fp16) { \
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
} else { \
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
}
if (device->shader_int64 && device->buffer_device_address) {
IM2COL(_bda)
} else {
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
IM2COL()
}

ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
Expand Down Expand Up @@ -4015,6 +4040,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->vendor_id != VK_VENDOR_ID_INTEL &&
getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;

device->shader_int64 = device_features2.features.shaderInt64;
device->buffer_device_address = vk12_features.bufferDeviceAddress;

if (device->subgroup_size_control) {
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
Expand Down Expand Up @@ -8592,6 +8620,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co

ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {
// buffer device address path doesn't use dst buffer
d_sz = 1;
}
// im2col uses only src1 and dst buffers
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_COUNT_EQUAL) {
Expand Down Expand Up @@ -9443,7 +9475,13 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co

const uint32_t pelements = OW * KW * KH;

const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
const vk_buffer d_buf = d_buf_ctx->dev_buffer;

const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;

ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, {
dst_addr,
batch_offset, offset_delta,
IC, IW, IH, OW, OH, KW, KH,
pelements,
Expand Down Expand Up @@ -9479,8 +9517,14 @@ static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx,
const int64_t OH = ne2;
const int64_t OW = ne1;

const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
const vk_buffer d_buf = d_buf_ctx->dev_buffer;

const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;

vk_op_im2col_3d_push_constants pc {};

pc.dst_addr = dst_addr;
pc.nb10 = nb10 / ggml_type_size(src1->type);
pc.nb11 = nb11 / ggml_type_size(src1->type);
pc.nb12 = nb12 / ggml_type_size(src1->type);
Expand Down
21 changes: 15 additions & 6 deletions ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

#include "rte.comp"

#include "types.comp"

layout (push_constant) uniform parameter
{
BDA_STORAGE_T dst_addr;
uint batch_offset; uint offset_delta;
uint IC;
uint IW; uint IH;
Expand All @@ -19,8 +22,6 @@ layout (push_constant) uniform parameter
int d0; int d1;
} p;

#include "types.comp"

layout(constant_id = 0) const uint BLOCK_SIZE = 32;

const uint NUM_ITER = 512 / BLOCK_SIZE;
Expand All @@ -30,6 +31,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

#if BDA
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif

void main() {
const uint gidx = gl_GlobalInvocationID.x;

Expand All @@ -38,7 +43,7 @@ void main() {
const uint ic = gl_GlobalInvocationID.z % p.IC;

const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
const int oh_s1 = int(oh) * p.s1;
const uint ksize = p.OW * p.KH;

Expand All @@ -50,7 +55,7 @@ void main() {
uint current_ix = rem % p.OW;

A_TYPE values[NUM_ITER];
uint offset_dst[NUM_ITER];
BDA_OFFSET_T offset_dst[NUM_ITER];
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
values[idx] = A_TYPE(0);
}
Expand All @@ -66,7 +71,7 @@ void main() {
const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0;
const uint iih = oh_s1 + current_ky * p.d1 - p.p1;

offset_dst[idx] = dst_base + current_ix * p.CHW + current_ky * p.KW + current_kx;
offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx;

if ((iih < p.IH) && (iiw < p.IW)) {
values[idx] = data_a[src_base + iih * p.IW + iiw];
Expand All @@ -89,7 +94,11 @@ void main() {
continue;
}

#if BDA
D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]);
dst_addr.d = D_TYPE(values[idx]);
#else
data_d[offset_dst[idx]] = D_TYPE(values[idx]);
#endif
}

}
22 changes: 18 additions & 4 deletions ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@

#include "rte.comp"

#include "types.comp"

layout (push_constant) uniform parameter
{
BDA_STORAGE_T dst_addr;
uint32_t nb10;
uint32_t nb11;
uint32_t nb12;
Expand Down Expand Up @@ -38,8 +41,6 @@ layout (push_constant) uniform parameter
uint32_t misalign_offsets;
} p;

#include "types.comp"

uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }

Expand All @@ -50,6 +51,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

#if BDA
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif

void main() {
const uint32_t i = gl_GlobalInvocationID.x;

Expand Down Expand Up @@ -100,13 +105,22 @@ void main() {
const uint32_t iih = ioh * s1 + ikh * d1 - p1;
const uint32_t iid = iod * s2 + ikd * d2 - p2;

const uint32_t offset_dst = in_*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
const BDA_OFFSET_T offset_dst = BDA_OFFSET_T(in_)*OD_OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(iod)*OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(ioh)*OW_IC_KD_KH_KW + BDA_OFFSET_T(iow)*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;

const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10;
#if BDA
D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst);
if (iih >= IH || iiw >= IW || iid >= ID) {
dst_addr.d = D_TYPE(0.0f);
} else {
dst_addr.d = D_TYPE(data_a[offset_src + get_aoffset()]);
}
#else
if (iih >= IH || iiw >= IW || iid >= ID) {
data_d[offset_dst + get_doffset()] = D_TYPE(0.0f);
} else {
const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10;
data_d[offset_dst + get_doffset()] = D_TYPE(data_a[offset_src + get_aoffset()]);
}
#endif
}
}
15 changes: 15 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/types.comp
Original file line number Diff line number Diff line change
Expand Up @@ -1447,4 +1447,19 @@ float e8m0_to_fp32(uint8_t x) {
return uintBitsToFloat(bits);
}

#if BDA

#extension GL_EXT_buffer_reference : enable
#extension GL_EXT_shader_explicit_arithmetic_types_int64 : enable

#define BDA_STORAGE_T uint64_t
#define BDA_OFFSET_T uint64_t

#else

#define BDA_STORAGE_T uvec2
#define BDA_OFFSET_T uint

#endif

#endif // !defined(GGML_TYPES_COMP)
16 changes: 9 additions & 7 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -775,13 +775,15 @@ void process_shaders() {
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));

string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));

string_to_spv("im2col_3d_f32", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("im2col_3d_f32_f16", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
string_to_spv("im2col_3d_f32_f16_rte", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
for (std::string dim_str : {"", "_3d"}) {
for (bool bda : {false, true}) {
std::string bda_str = bda ? "_bda" : "";
std::string bda_def = bda ? "1" : "0";
string_to_spv("im2col" + dim_str + "_f32" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"D_SIZE", "4"}, {"BDA", bda_def}}));
string_to_spv("im2col" + dim_str + "_f32_f16" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"BDA", bda_def}}));
string_to_spv("im2col" + dim_str + "_f32_f16_rte" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"RTE16", "1"}, {"BDA", bda_def}}));
}
}

string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

Expand Down
7 changes: 7 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5753,6 +5753,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}

#if 0
// >4GB im2col destination. Too slow to run by default.
// Test cases taken from Wan2.1 T2V 1.3B.
test_cases.emplace_back(new test_im2col (GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {832, 480, 192, 4}, {3, 3, 192, 96}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {834, 482, 6, 96}, {3, 3,3, 9216}, 96, 1, 1, 1, 0, 0, 0, 1, 1, 1, false));
#endif

// im2col 1D
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
Expand Down
Loading