Skip to content

Commit 89a6e3d

Browse files
committed
Working AMD 16x16 instruction transpose
1 parent 19f933e commit 89a6e3d

File tree

1 file changed

+41
-3
lines changed
  • tensorforge/include/tensorforge_device

1 file changed

+41
-3
lines changed

tensorforge/include/tensorforge_device/hip.h

Lines changed: 41 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -676,11 +676,49 @@ transpose16x16b32(T &w1, T &w2, T &w3, T &w4, T &w5, T &w6, T &w7, T &w8, T &w9,
676676
transpose4x4b32(v9, v10, v11, v12, w9, w10, w11, w12);
677677
transpose4x4b32(v13, v14, v15, v16, w13, w14, w15, w16);
678678

679-
// transpose 8x8
679+
// from here on: DPP and row control suffice
680680

681-
// DPP and row control suffice here
681+
// transpose 8x8
682682

683-
// const T u1 = dppUpdate<0x128, 0b1010, 0b1111, true>(v1, v5);
683+
const T u1 = dppUpdate<0x124, 0b1111, 0b1010, true>(v5, v1);
684+
const T u2 = dppUpdate<0x124, 0b1111, 0b1010, true>(v6, v2);
685+
const T u3 = dppUpdate<0x124, 0b1111, 0b1010, true>(v7, v3);
686+
const T u4 = dppUpdate<0x124, 0b1111, 0b1010, true>(v8, v4);
687+
688+
const T u5 = dppUpdate<0x12c, 0b1111, 0b0101, true>(v1, v5);
689+
const T u6 = dppUpdate<0x12c, 0b1111, 0b0101, true>(v2, v6);
690+
const T u7 = dppUpdate<0x12c, 0b1111, 0b0101, true>(v3, v7);
691+
const T u8 = dppUpdate<0x12c, 0b1111, 0b0101, true>(v4, v8);
692+
693+
const T u9 = dppUpdate<0x124, 0b1111, 0b1010, true>(v13, v9);
694+
const T u10 = dppUpdate<0x124, 0b1111, 0b1010, true>(v14, v10);
695+
const T u11 = dppUpdate<0x124, 0b1111, 0b1010, true>(v15, v11);
696+
const T u12 = dppUpdate<0x124, 0b1111, 0b1010, true>(v16, v12);
697+
698+
const T u13 = dppUpdate<0x12c, 0b1111, 0b0101, true>(v9, v13);
699+
const T u14 = dppUpdate<0x12c, 0b1111, 0b0101, true>(v10, v14);
700+
const T u15 = dppUpdate<0x12c, 0b1111, 0b0101, true>(v11, v15);
701+
const T u16 = dppUpdate<0x12c, 0b1111, 0b0101, true>(v12, v16);
702+
703+
// transpose 16x16
704+
705+
w1 = dppUpdate<0x128, 0b1111, 0b1100, true>(u9, u1);
706+
w2 = dppUpdate<0x128, 0b1111, 0b1100, true>(u10, u2);
707+
w3 = dppUpdate<0x128, 0b1111, 0b1100, true>(u11, u3);
708+
w4 = dppUpdate<0x128, 0b1111, 0b1100, true>(u12, u4);
709+
w5 = dppUpdate<0x128, 0b1111, 0b1100, true>(u13, u5);
710+
w6 = dppUpdate<0x128, 0b1111, 0b1100, true>(u14, u6);
711+
w7 = dppUpdate<0x128, 0b1111, 0b1100, true>(u15, u7);
712+
w8 = dppUpdate<0x128, 0b1111, 0b1100, true>(u16, u8);
713+
714+
w9 = dppUpdate<0x128, 0b1111, 0b0011, true>(u1, u9);
715+
w10 = dppUpdate<0x128, 0b1111, 0b0011, true>(u2, u10);
716+
w11 = dppUpdate<0x128, 0b1111, 0b0011, true>(u3, u11);
717+
w12 = dppUpdate<0x128, 0b1111, 0b0011, true>(u4, u12);
718+
w13 = dppUpdate<0x128, 0b1111, 0b0011, true>(u5, u13);
719+
w14 = dppUpdate<0x128, 0b1111, 0b0011, true>(u6, u14);
720+
w15 = dppUpdate<0x128, 0b1111, 0b0011, true>(u7, u15);
721+
w16 = dppUpdate<0x128, 0b1111, 0b0011, true>(u8, u16);
684722
}
685723

686724
#define CM4STR(p1, p2, p3, p4, c, a, b) \

0 commit comments

Comments
 (0)