@@ -544,34 +544,51 @@ def hfma(writer: Writer, C, A, B, repeat, datatype, threads, ctx):
544544 if b is not None :
545545 func (writer , c , a , b , j )
546546
def wmma3atom(writer, A, B, C, threads):
    """Emit the split-bf16 ("wmma3") emulation of an f32 16x16x16 GEMM atom.

    Each f32 element is split into three bf16 parts (p1/p2/p3); the six
    significant partial products (p1*p1, p2*p1, p1*p2, p3*p1, p1*p3, p2*p2)
    are accumulated with the gfx11 WMMA builtin. The whole sequence is
    emitted twice, once per 16-lane half of the 32-lane wavefront.

    Args:
        writer:  code emitter — calling it appends one source line; also
                 provides varalloc() and the AnonymousScope() context manager.
        A, B, C: name prefixes of operand variables in the emitted code.
                 Reads {A}_0..{A}_15 and {B}_0..{B}_15 — presumably defined
                 by earlier emitted code; TODO confirm against callers.
        threads: wavefront width; only 32 is supported.
    """
    # Fresh names for the temporaries this atom emits.
    a = writer.varalloc()
    b = writer.varalloc()
    c = writer.varalloc()

    assert threads == 32

    # WMMA atom tile dimensions (M and K currently unused; kept as
    # documentation of the 16x16x16 shape).
    N = 16
    M = 16
    K = 16

    # One pass per 16-lane half; the broadcast's third template argument
    # selects which half the lanes read from.
    for half in range(2):
        with writer.AnonymousScope():
            for row in range(N):
                writer(f'const auto {a}_{row} = tensorforge::broadcast<32, 16, {half}>({A}_{row});')
            for col in range(N):
                writer(f'const auto {b}_{col} = tensorforge::broadcast<32, 16, {half}>({B}_{col});')

            # Transpose B in place so both operands feed the builtin in the
            # layout it expects.
            writer(f'tensorforge::transpose16x16({",".join(f"{b}_{i}" for i in range(N))});')

            # Three bf16 "digit" planes per operand, value-initialized ({}).
            for operand in (a, b):
                for part in (1, 2, 3):
                    writer(f'VectorT<short, 16> {operand}_p{part}{"{}"};')

            # NOTE(review): '[x, y, z] = f(...)' is not standard C++ for
            # assigning to existing lvalues — presumably rewritten by a
            # project-side tool/macro; verify before relying on it.
            for row in range(N):
                writer(f'[{a}_p1[{row}], {a}_p2[{row}], {a}_p3[{row}]] = splitFloatBF16({a}_{row});')
            for row in range(N):
                writer(f'[{b}_p1[{row}], {b}_p2[{row}], {b}_p3[{row}]] = splitFloatBF16({b}_{row});')

            # Accumulate the six significant partial products; the tiny
            # p2*p3, p3*p2 and p3*p3 terms are deliberately dropped.
            writer(f'VectorT<float, 8> {c}{"{}"};')
            for pa, pb in ((1, 1), (2, 1), (1, 2), (3, 1), (1, 3), (2, 2)):
                writer(f'{c} = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32({a}_p{pa}, {b}_p{pb}, {c});')

            # Expose the accumulator rows for this half.
            for row in range(N):
                writer(f'const auto {c}_{row} = tensorforge::broadcast<32, 16, {half}>({c}[{row}]);')
575592
576593 # TODO: gfx1200, f'__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12'
577594
0 commit comments