Skip to content

Commit c21c8cd

Browse files
committed
Merge branch 'upstream' into concedo_experimental
2 parents 2f645bb + 7604a7d commit c21c8cd

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3192,7 +3192,7 @@ kernel void kernel_flash_attn_ext(
31923192

31933193
{
31943194
float S[Q] = { [0 ... Q-1] = 0.0f };
3195-
float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
3195+
float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 };
31963196

31973197
// thread indices inside the simdgroup
31983198
// TODO: see if we can utilize quad-group functions for better performance
@@ -3452,7 +3452,7 @@ kernel void kernel_flash_attn_ext(
34523452
// reduce the warps sequentially
34533453
for (ushort sg = 1; sg < nsg; ++sg) {
34543454
float S = { 0.0f };
3455-
float M = { -__FLT16_MAX__/2 };
3455+
float M = { -__FLT_MAX__/2 };
34563456

34573457
threadgroup_barrier(mem_flags::mem_threadgroup);
34583458

@@ -3699,7 +3699,7 @@ kernel void kernel_flash_attn_ext_vec(
36993699

37003700
{
37013701
float S = 0.0f;
3702-
float M = -__FLT16_MAX__/2;
3702+
float M = -__FLT_MAX__/2;
37033703

37043704
// thread indices inside the simdgroup
37053705
const short tx = tiisg%NL;

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -262,6 +262,7 @@ struct vk_device_struct {
262262
bool pipeline_robustness;
263263
vk::Device device;
264264
uint32_t vendor_id;
265+
vk::DriverId driver_id;
265266
vk_device_architecture architecture;
266267
vk_queue compute_queue;
267268
vk_queue transfer_queue;
@@ -1756,6 +1757,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
17561757
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
17571758
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
17581759

1760+
// chip specific tuning
1761+
if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
1762+
m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
1763+
}
1764+
17591765
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
17601766
m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
17611767
s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
@@ -2678,6 +2684,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
26782684
device->physical_device.getProperties2(&props2);
26792685
device->properties = props2.properties;
26802686
device->vendor_id = device->properties.vendorID;
2687+
device->driver_id = driver_props.driverID;
26812688

26822689
const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
26832690

0 commit comments

Comments
 (0)