@@ -408,18 +408,11 @@ def mfma_emu_bf16_f32(writer: Writer, C, B, A, c, a, b):
408408 writer (f'{ C } = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k({ A [0 ]} _p1, { B [0 ]} _p1, { C } , { c } , { a } , { b } );' )
409409
410410def mfma_emu_f16_f32 (writer : Writer , C , B , A , c , a , b ):
# Emit, through `writer`, the generated statements that accumulate a product
# into `C` using the `__builtin_amdgcn_mfma_f32_4x4x4f16` builtin.
# This block is shown as a diff: `-` lines are the removed body, `+` lines are
# the replacement body.
#
# Parameters, as used by the `+` lines:
#   writer  -- called once per generated statement with that statement's text.
#   C       -- interpolated both as the assignment target and as the
#              accumulator argument of every builtin call.
#   B, A    -- indexed at [0]..[3]; the four elements of each are interpolated
#              as the arguments of one `tensorforge::splitFloatx4F16` call.
#   c, a, b -- interpolated unchanged as the last three builtin arguments
#              (presumably the cbsz/abid/blgp modifiers -- TODO confirm).
#
# Removed body: allocated six temporaries via `writer.varalloc()` and emitted
# only `f16x4` conversion declarations; none of these lines emitted an MFMA call.
411- Ar = writer .varalloc ()
412- A1 = writer .varalloc ()
413- A2 = writer .varalloc ()
414- Br = writer .varalloc ()
415- B1 = writer .varalloc ()
416- B2 = writer .varalloc ()
417- writer (f'const f16x4 { Ar } = f16x4({ A } );' )
418- writer (f'const f16x4 { Br } = f16x4({ B } );' )
419- writer (f'const f16x4 { A1 } = f16x4({ A } );' )
420- writer (f'const f16x4 { B1 } = f16x4({ B } );' )
421- writer (f'const f16x4 { A2 } = f16x4({ A } - { A1 } );' )
422- writer (f'const f16x4 { B2 } = f16x4({ B } - { B1 } );' )
# New body, step 1: split each operand into two parts bound by a structured
# binding named `<X[0]>_p0` / `<X[0]>_p1` (presumably the leading f16 part and
# the f16 residual -- confirm against tensorforge::splitFloatx4F16).
# NOTE(review): the binding names derive only from A[0] / B[0]; emitting this
# twice into one generated scope with the same first element would redeclare
# them -- confirm every call lands in its own scope.
411+ writer (f'const auto [{ A [0 ]} _p0, { A [0 ]} _p1] = tensorforge::splitFloatx4F16({ A [0 ]} , { A [1 ]} , { A [2 ]} , { A [3 ]} );' )
412+ writer (f'const auto [{ B [0 ]} _p0, { B [0 ]} _p1] = tensorforge::splitFloatx4F16({ B [0 ]} , { B [1 ]} , { B [2 ]} , { B [3 ]} );' )
# New body, step 2: three accumulating builtin calls, pairing the parts as
# (A_p0, B_p0), (A_p1, B_p0), (A_p0, B_p1). No (A_p1, B_p1) call is emitted.
413+ writer (f'{ C } = __builtin_amdgcn_mfma_f32_4x4x4f16({ A [0 ]} _p0, { B [0 ]} _p0, { C } , { c } , { a } , { b } );' )
414+ writer (f'{ C } = __builtin_amdgcn_mfma_f32_4x4x4f16({ A [0 ]} _p1, { B [0 ]} _p0, { C } , { c } , { a } , { b } );' )
415+ writer (f'{ C } = __builtin_amdgcn_mfma_f32_4x4x4f16({ A [0 ]} _p0, { B [0 ]} _p1, { C } , { c } , { a } , { b } );' )
423416
424417def matmul32 (writer : Writer , C , B , A , M , N , K , kx , threads ):
425418 with writer .AnonymousScope ():
@@ -486,10 +479,14 @@ def write_matmul(block, start, cap):
486479 fB [kkk ] = B (writer , f'{ tmpB } _{ kkk } ' , i , k + kk + kkk )
487480 for kkk in range (dkk , block ):
488481 writer (f'float { tmpB } _{ kkk } = 0;' )
489- if True :
482+ if False :
490483 Ar = [f'{ tmpA } _{ k // threads } _{ kkk } ' for kkk in range (4 )]
491484 Br = [f'{ tmpB } _{ kkk } ' for kkk in range (4 )]
492485 mfma_emu_bf16_f32 (writer , tmpacc , Br , Ar , scale , kk // 4 , 0 )
486+ elif True :
487+ Ar = [f'{ tmpA } _{ k // threads } _{ kkk } ' for kkk in range (4 )]
488+ Br = [f'{ tmpB } _{ kkk } ' for kkk in range (4 )]
489+ mfma_emu_f16_f32 (writer , tmpacc , Br , Ar , scale , kk // 4 , 0 )
493490 else :
494491 for kkk in range (dkk ):
495492 if fB [kkk ]:
0 commit comments