Commit 41c9b26

vulkan: fix coopmat2 validation failures
mul mat and flash attention shaders were loading f32 types directly into A/B matrices, which happens to work but is technically invalid usage. For FA, we can load it as an Accumulator matrix and convert it; this is not in the inner loop and is cheap enough. For mul mat, it's more efficient to do this conversion in a separate pass and have the input(s) be f16. coopmat2 requires SPIR-V 1.6 (related to using LocalSizeId). LocalSizeId requires maintenance4 to be enabled, and SPIR-V 1.6 requires Vulkan 1.3.
1 parent bd38dde commit 41c9b26
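
A minimal GLSL sketch of the flash-attention pattern described above (buffer and layout names such as data_q, tensorLayoutQ, and q_offset are illustrative; only the coopmat declarations match the shader): the f32 Q tile is loaded with Accumulator use and converted once to an f16 A-matrix operand before the inner loop.

    // Load Q (f32 in memory) with Accumulator use, then convert once to an
    // f16 MatrixUseA operand; the conversion sits outside the KV loop, so its
    // cost is paid once per workgroup rather than per iteration.
    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q;
    coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;

    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, 0, Br, 0, D));
    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA>(Q);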

4 files changed (+53, -56 lines)

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 38 additions & 27 deletions
@@ -29,8 +29,6 @@
 
 #include "ggml-vulkan-shaders.hpp"
 
-#define VK_API_VERSION VK_API_VERSION_1_2
-
 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
 
 #define VK_VENDOR_ID_AMD 0x1002
@@ -1611,11 +1609,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
         CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
 
-        CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
-        CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
-
         CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
-        CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
@@ -1628,21 +1622,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
 
-        CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
         CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
-
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
 #undef CREATE_MM
 #undef CREATE_MM2
     } else
@@ -2284,6 +2275,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
     }
 #endif
 
+    VkPhysicalDeviceMaintenance4Features maint4_features {};
+    maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
+    if (maintenance4_support) {
+        last_struct->pNext = (VkBaseOutStructure *)&maint4_features;
+        last_struct = (VkBaseOutStructure *)&maint4_features;
+        device_extensions.push_back("VK_KHR_maintenance4");
+    }
+
     vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
 
     device->fp16 = device->fp16 && vk12_features.shaderFloat16;
@@ -2659,7 +2658,14 @@ void ggml_vk_instance_init() {
 
     vk_instance_initialized = true;
 
-    vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION };
+    uint32_t api_version = vk::enumerateInstanceVersion();
+
+    if (api_version < VK_API_VERSION_1_2) {
+        std::cerr << "ggml_vulkan: Error: Vulkan 1.2 required." << std::endl;
+        GGML_ABORT("fatal error");
+    }
+
+    vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version };
 
     const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
     const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions);
@@ -2969,7 +2975,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
         }
     }
 
-    GGML_ASSERT(src1_type == GGML_TYPE_F32);
+    GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));
 
     switch (src0_type) {
         case GGML_TYPE_Q4_0:
@@ -3809,8 +3815,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         src1_uma = d_Qy != nullptr;
     }
 
-    const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
-    // Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf
+    // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
+    const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
+        !ggml_vk_dim01_contiguous(src0);
     const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
         !ggml_vk_dim01_contiguous(src1);
 
@@ -4390,8 +4397,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
        ids_uma = d_ids != nullptr;
    }
 
-    const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
-    const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
+    // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
+    const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
+        !ggml_vk_dim01_contiguous(src0);
+    const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
+        !ggml_vk_dim01_contiguous(src1);
 
     const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
 
@@ -4401,7 +4411,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
 
     if (qx_needs_dequant) {
-        GGML_ABORT("fatal error");
+        // Fall back to dequant + f16 mulmat
+        mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
     }
 
     // Not implemented
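
The "separate pass" for mul mat mentioned in the commit message is an ordinary element-wise f32-to-f16 conversion of the input before the coopmat2 matmul runs. A minimal illustrative compute shader for such a pass is sketched below; it is not the repo's actual conversion shader (whose bindings and stride handling are not shown here) and assumes a contiguous input.

    #version 450
    #extension GL_EXT_shader_16bit_storage : require
    #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require

    layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

    layout(binding = 0) readonly  buffer Src { float     data_src[]; };
    layout(binding = 1) writeonly buffer Dst { float16_t data_dst[]; };

    layout(push_constant) uniform PushConstants { uint ne; } p;

    void main() {
        const uint i = gl_GlobalInvocationID.x;
        if (i >= p.ne) {
            return;
        }
        // plain element-wise conversion; a real pass must also handle
        // non-contiguous layouts, which is omitted in this sketch
        data_dst[i] = float16_t(data_src[i]);
    }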

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 1 addition & 1 deletion
@@ -146,7 +146,7 @@ void main() {
     tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
     tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
 
-    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Q;
+    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q;
     coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
 
     uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

Lines changed: 9 additions & 26 deletions
@@ -57,17 +57,13 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
 
 #if QUANT_K > 1
 #define DECODEFUNCA , dequantFuncA
-#define MAT_A_TYPE float16_t
 
 #include "dequant_funcs_cm2.comp"
 
 #else
 #define DECODEFUNCA
-#define MAT_A_TYPE A_TYPE
 #endif
 
-#define MAT_B_TYPE B_TYPE
-
 #ifdef MUL_MAT_ID
 layout (binding = 3) readonly buffer IDS {int data_ids[];};
 
@@ -236,16 +232,13 @@ void main() {
 
         for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
 
-            coopmat<MAT_A_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-            coopmat<MAT_B_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
 
             coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
-
             coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
 
-            sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+            sum = coopMatMulAdd(mat_a, mat_b, sum);
         }
     } else
 #endif // !defined(MUL_MAT_ID)
@@ -261,10 +254,8 @@ void main() {
         [[dont_unroll]]
         for (uint block_k = start_k; block_k < end_k; block_k += BK) {
 
-            coopmat<MAT_A_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-            coopmat<MAT_B_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a_ft;
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b_ft;
+            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
 
             // Clamping is expensive, so detect different code paths for each combination
             // of A and B needing clamping.
@@ -281,33 +272,25 @@ void main() {
 #else
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
 #endif
-                mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
-                mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
-                sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
             } else if (unclampedA && !unclampedB) {
                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
 
-                mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
-                mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
-                sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
             } else if (!unclampedA && unclampedB) {
                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
 #ifdef MUL_MAT_ID
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
 #else
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
 #endif
-                mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
-                mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
-                sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
             } else if (!unclampedA && !unclampedB) {
                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
 
-                mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
-                mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
-                sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
             }
         }
     }

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 5 additions & 2 deletions
@@ -316,8 +316,11 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
         // For aligned matmul loads
         std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
 
-        string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
-        string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+        // don't generate f32 variants for coopmat2
+        if (!coopmat2) {
+            string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+        }
 
         if (tname != "f16" && tname != "f32") {
             string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
