Skip to content

Commit ae43f64

Browse files
committed
Towards more AMD and Intel matrix ops
1 parent 685f544 commit ae43f64

File tree

3 files changed

+68
-22
lines changed
  • tensorforge

3 files changed

+68
-22
lines changed

tensorforge/backend/instructions/compute/primitives/amd.py

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -544,34 +544,51 @@ def hfma(writer: Writer, C, A, B, repeat, datatype, threads, ctx):
544544
if b is not None:
545545
func(writer, c, a, b, j)
546546

547-
def wmma3atom(threads):
547+
def wmma3atom(writer, A, B, C, threads):
548+
549+
a = writer.varalloc()
550+
b = writer.varalloc()
551+
c = writer.varalloc()
552+
548553
assert threads == 32
549554

550555
N = 16
551556
M = 16
552557
K = 16
553558

554-
for i in range(N):
555-
writer(f'const auto {a}_{i} = tensorforge::broadcast<32, 16, 0>({A}_{i});')
556-
for j in range(N):
557-
writer(f'const auto {b}_{j} = tensorforge::broadcast<32, 16, 0>({B}_{j});')
558-
559-
writer(f'tensorforge::transpose16x16({",".join(f"{b}_{i}" for i in range(N))});')
560-
561-
writer(f'VectorT<short, 16> {a}_p1;')
562-
writer(f'VectorT<short, 16> {a}_p2;')
563-
writer(f'VectorT<short, 16> {a}_p3;')
564-
writer(f'VectorT<short, 16> {b}_p1;')
565-
writer(f'VectorT<short, 16> {b}_p2;')
566-
writer(f'VectorT<short, 16> {b}_p3;')
567-
568-
writer(f'VectorT<float, 8> {c}{"{}"};')
569-
writer(f'{c} = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32({a}_p1, {b}_p1, {c});')
570-
writer(f'{c} = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32({a}_p2, {b}_p1, {c});')
571-
writer(f'{c} = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32({a}_p1, {b}_p2, {c});')
572-
writer(f'{c} = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32({a}_p3, {b}_p1, {c});')
573-
writer(f'{c} = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32({a}_p1, {b}_p3, {c});')
574-
writer(f'{c} = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32({a}_p2, {b}_p2, {c});')
559+
for m in range(2):
560+
with writer.AnonymousScope():
561+
for i in range(N):
562+
writer(f'const auto {a}_{i} = tensorforge::broadcast<32, 16, {m}>({A}_{i});')
563+
for j in range(N):
564+
writer(f'const auto {b}_{j} = tensorforge::broadcast<32, 16, {m}>({B}_{j});')
565+
566+
writer(f'tensorforge::transpose16x16({",".join(f"{b}_{i}" for i in range(N))});')
567+
568+
writer(f'VectorT<short, 16> {a}_p1{"{}"};')
569+
writer(f'VectorT<short, 16> {a}_p2{"{}"};')
570+
writer(f'VectorT<short, 16> {a}_p3{"{}"};')
571+
writer(f'VectorT<short, 16> {b}_p1{"{}"};')
572+
writer(f'VectorT<short, 16> {b}_p2{"{}"};')
573+
writer(f'VectorT<short, 16> {b}_p3{"{}"};')
574+
575+
for i in range(N):
576+
writer(f'[{a}_p1[{i}], {a}_p2[{i}], {a}_p3[{i}]] = splitFloatBF16({a}_{i});')
577+
for i in range(N):
578+
writer(f'[{b}_p1[{i}], {b}_p2[{i}], {b}_p3[{i}]] = splitFloatBF16({b}_{i});')
579+
580+
writer(f'VectorT<float, 8> {c}{"{}"};')
581+
writer(f'{c} = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32({a}_p1, {b}_p1, {c});')
582+
writer(f'{c} = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32({a}_p2, {b}_p1, {c});')
583+
writer(f'{c} = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32({a}_p1, {b}_p2, {c});')
584+
writer(f'{c} = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32({a}_p3, {b}_p1, {c});')
585+
writer(f'{c} = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32({a}_p1, {b}_p3, {c});')
586+
writer(f'{c} = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32({a}_p2, {b}_p2, {c});')
587+
588+
for j in range(N):
589+
writer(f'const auto {c}_{j} = tensorforge::broadcast<32, 16, {m}>({c}[{j}]);')
590+
591+
575592

576593
# TODO: gfx1200, f'__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12'
577594

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
2+
def dpas(C, B, A, rc, sd):
3+
# cf. https://github.com/intel/intel-graphics-compiler/blob/master/documentation/visa/instructions/DPAS.md
4+
# sd == depth == k * elemsIn32Bit
5+
# rc == m
6+
writer(f'asm("DPAS.tf32.tf32.{sd}.{rc} (16) %[D], %[C], %[B], %[A]" : [D]"=f"({C}) : [C]"f"({C}), [B]"d"({B}), [A]"d"({A}) :);')
7+
8+
def matmul(writer, C, A, B, M, N, K, kx, threads, dtype, sparse, ctx):
9+
10+
11+
rc = 8
12+
sd = 8
13+
14+
dpas(C, A, B, rc, sd)
15+
16+
# TODO

tensorforge/include/tensorforge_device/hip.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <hip/hip_cooperative_groups.h>
66

77
#include <type_traits>
8+
#include <utility>
89

910
#include "base.h"
1011

@@ -807,4 +808,16 @@ class Loader {
807808
};
808809
*/
809810

811+
std::tuple<short, short, short> splitFloatBF16(float input) {
812+
const auto i1 = static_cast<__bf16>(input);
813+
const auto i1r = input - static_cast<float>(i1);
814+
const auto i2 = static_cast<__bf16>(i1r);
815+
const auto i2r = i1r - static_cast<float>(i2);
816+
const auto i3 = static_cast<__bf16>(i2r);
817+
const auto r1 = *reinterpret_cast<const short *>(&i1);
818+
const auto r2 = *reinterpret_cast<const short *>(&i2);
819+
const auto r3 = *reinterpret_cast<const short *>(&i3);
820+
return {r1, r2, r3};
821+
}
822+
810823
} // namespace tensorforge

0 commit comments

Comments
 (0)