@@ -392,34 +392,34 @@ def mfma_emu_int8(writer: Writer, C, B, A, c, a, b):
392392 c = writer .varalloc ()
393393 writer (f'const auto { a } = static_cast<uint8x4_t>({ Aa } % { x } );' )
394394 writer (f'const auto { b } = static_cast<uint8x4_t>({ Ba } % { x } );' )
395- writer (f'{ c } = __builtin_amdgcn_mfma_i32_4x4x4i8(get_native_vector( { a } ), get_native_vector( { b } ) , 0, { c } , { a } , { b } );' )
395+ writer (f'{ c } = __builtin_amdgcn_mfma_i32_4x4x4i8({ a } , { b } , 0, { c } , { a } , { b } );' )
396396 writer (f'{ Ca } += { c } * { y } ;' )
397397
398398 # TODO: scale back
399399
400400def mfma_emu_bf16_f32 (writer : Writer , C , B , A , c , a , b ):
401+ writer (f'const auto [{ A [0 ]} _p0, { A [0 ]} _p1, { A [0 ]} _p2] = tensorforge::splitFloatx4BF16({ A [0 ]} , { A [1 ]} , { A [2 ]} , { A [3 ]} );' )
402+ writer (f'const auto [{ B [0 ]} _p0, { B [0 ]} _p1, { B [0 ]} _p2] = tensorforge::splitFloatx4BF16({ B [0 ]} , { B [1 ]} , { B [2 ]} , { B [3 ]} );' )
403+ writer (f'{ C } = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k({ A [0 ]} _p0, { B [0 ]} _p0, { C } , { c } , { a } , { b } );' )
404+ writer (f'{ C } = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k({ A [0 ]} _p0, { B [0 ]} _p1, { C } , { c } , { a } , { b } );' )
405+ writer (f'{ C } = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k({ A [0 ]} _p1, { B [0 ]} _p0, { C } , { c } , { a } , { b } );' )
406+ writer (f'{ C } = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k({ A [0 ]} _p0, { B [0 ]} _p2, { C } , { c } , { a } , { b } );' )
407+ writer (f'{ C } = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k({ A [0 ]} _p2, { B [0 ]} _p0, { C } , { c } , { a } , { b } );' )
408+ writer (f'{ C } = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k({ A [0 ]} _p1, { B [0 ]} _p1, { C } , { c } , { a } , { b } );' )
409+
410+ def mfma_emu_f16_f32 (writer : Writer , C , B , A , c , a , b ):
411+ Ar = writer .varalloc ()
401412 A1 = writer .varalloc ()
402413 A2 = writer .varalloc ()
403- A3 = writer .varalloc ()
414+ Br = writer .varalloc ()
404415 B1 = writer .varalloc ()
405416 B2 = writer .varalloc ()
406- B3 = writer .varalloc ()
407- Ar = writer .varalloc ()
408- Br = writer .varalloc ()
409- writer (f'const bfloat16x4 { A1 } = bfloat16x4({ A } );' )
410- writer (f'const bfloat16x4 { B1 } = bfloat16x4({ B } );' )
411- writer (f'const bfloat16x4 { Ar } = { A } - { A1 } ;' )
412- writer (f'const bfloat16x4 { Br } = { B } - { B1 } ;' )
413- writer (f'const bfloat16x4 { A2 } = bfloat16x4({ Ar } );' )
414- writer (f'const bfloat16x4 { B2 } = bfloat16x4({ Br } );' )
415- writer (f'const bfloat16x4 { A3 } = bfloat16x4({ Ar } - { A2 } );' )
416- writer (f'const bfloat16x4 { B3 } = bfloat16x4({ Br } - { B2 } );' )
417- writer (f'{ C } = __builtin_amdgcn_mfma_f32_4x4x4bf16(get_native_vector({ A1 } ), get_native_vector({ B1 } ), { C } , { c } , { a } , { b } );' )
418- writer (f'{ C } = __builtin_amdgcn_mfma_f32_4x4x4bf16(get_native_vector({ A1 } ), get_native_vector({ B2 } ), { C } , { c } , { a } , { b } );' )
419- writer (f'{ C } = __builtin_amdgcn_mfma_f32_4x4x4bf16(get_native_vector({ A2 } ), get_native_vector({ B1 } ), { C } , { c } , { a } , { b } );' )
420- writer (f'{ C } = __builtin_amdgcn_mfma_f32_4x4x4bf16(get_native_vector({ A1 } ), get_native_vector({ B3 } ), { C } , { c } , { a } , { b } );' )
421- writer (f'{ C } = __builtin_amdgcn_mfma_f32_4x4x4bf16(get_native_vector({ A3 } ), get_native_vector({ B1 } ), { C } , { c } , { a } , { b } );' )
422- writer (f'{ C } = __builtin_amdgcn_mfma_f32_4x4x4bf16(get_native_vector({ A2 } ), get_native_vector({ B2 } ), { C } , { c } , { a } , { b } );' )
417+ writer (f'const f16x4 { Ar } = f16x4({ A } );' )
418+ writer (f'const f16x4 { Br } = f16x4({ B } );' )
419+ writer (f'const f16x4 { A1 } = f16x4({ A } );' )
420+ writer (f'const f16x4 { B1 } = f16x4({ B } );' )
421+ writer (f'const f16x4 { A2 } = f16x4({ A } - { A1 } );' )
422+ writer (f'const f16x4 { B2 } = f16x4({ B } - { B1 } );' )
423423
424424def matmul32 (writer : Writer , C , B , A , M , N , K , kx , threads ):
425425 with writer .AnonymousScope ():
@@ -449,11 +449,13 @@ def write_matmul(block, start, cap):
449449 }[threads ]
450450 }[block ]
451451 fn = {
452+ 1 : f'fmacdpp16<0>()' ,
452453 4 : '__builtin_amdgcn_mfma_f32_4x4x1f32' ,
453454 16 : '__builtin_amdgcn_mfma_f32_16x16x1f32' ,
454455 32 : '__builtin_amdgcn_mfma_f32_32x32x1f32'
455456 }[block ]
456457 tp = {
458+ 1 : lambda tmpA : '' ,
457459 4 : lambda tmpA : f'tensorforge::transpose4x4b32({ tmpA } _0, { tmpA } _1, { tmpA } _2, { tmpA } _3, { tmpA } _0, { tmpA } _1, { tmpA } _2, { tmpA } _3)' ,
458460 16 : lambda tmpA : f'tensorforge::transpose16x16b32({ ", " .join (f"{ tmpA } _{ i } " for i in range (16 ))} )' ,
459461 32 : lambda tmpA : f'tensorforge::transpose32x32b32({ ", " .join (f"{ tmpA } _{ i } " for i in range (32 ))} )'
@@ -474,21 +476,33 @@ def write_matmul(block, start, cap):
474476 for i in range (0 , M ):
475477 with writer .AnonymousScope ():
476478 writer (f'tensorforge::VectorT<float, { block } > { tmpacc } { "{}" } ;' )
477- for k in range (0 , K , threads ):
478- dk = min (threads , K - k )
479+ for k in range (0 , K + kx , threads ):
480+ dk = min (threads , K + kx - k )
479481 for kk in range (0 , dk , block ):
480482 with writer .AnonymousScope ():
481483 fB = [False ] * block
482- for kkk in range (min (block , dk - kk )):
484+ dkk = min (block , dk - kk )
485+ for kkk in range (dkk ):
483486 fB [kkk ] = B (writer , f'{ tmpB } _{ kkk } ' , i , k + kk + kkk )
484- for kkk in range (min (block , dk - kk )):
485- if fB [kkk ]:
486- trueK = k + kk + kkk + kx
487- km = trueK // threads
488- kkm = ((trueK % threads ) // block )
489- kkkm = trueK % block
490- # the index for tmpB is correct
491- writer (f'{ tmpacc } = { fn } ({ tmpA } _{ km } _{ kkkm } , { tmpB } _{ kkk } , { tmpacc } , { scale } , { kkm } , 0);' )
487+ for kkk in range (dkk , block ):
488+ writer (f'float { tmpB } _{ kkk } = 0;' )
489+ if True :
490+ Ar = [f'{ tmpA } _{ k // threads } _{ kkk } ' for kkk in range (4 )]
491+ Br = [f'{ tmpB } _{ kkk } ' for kkk in range (4 )]
492+ mfma_emu_bf16_f32 (writer , tmpacc , Br , Ar , scale , kk // 4 , 0 )
493+ else :
494+ for kkk in range (dkk ):
495+ if fB [kkk ]:
496+ trueK = k + kk + kkk #+ kx
497+ km = trueK // threads
498+ kkm = ((trueK % threads ) // block )
499+ kkkm = trueK % block
500+
501+ assert km == k
502+ assert kkm == kk
503+ assert kkkm == kkk
504+ # the index for tmpB is correct
505+ writer (f'{ tmpacc } = { fn } ({ tmpA } _{ km } _{ kkkm } , { tmpB } _{ kkk } , { tmpacc } , { scale } , { kkm } , 0);' )
492506
493507 for jj in range (min (block , N - j )):
494508 C (writer , f'{ tmpacc } [{ jj } ]' , i , j + jj )
@@ -500,7 +514,9 @@ def write_matmul(block, start, cap):
500514 #if N >= 16 and threads >= 16:
501515 # write_matmul(16, start, True)
502516 # start += (N // 16) * 16
503- write_matmul (4 , start , False )
517+ cap4 = False #N % 4 < 2
518+ write_matmul (4 , start , cap4 )
519+ # write_matmul(1, )
504520
505521def fmadpp16 (writer , C , A , B , row ):
506522 writer (f'tensorforge::fmacdpp16<{ row } >({ C } , { A } , { B } );' )
0 commit comments