Skip to content

Commit 570f24d

Browse files
authored
[AMD] Improve register usage in Float8 conversions (#7527)
This PR updates the intrinsics of Float8 conversions to make full usage of f8x4 registers.
1 parent 96e91d4 commit 570f24d

File tree

2 files changed

+284
-219
lines changed

2 files changed

+284
-219
lines changed

test/Conversion/amd/fp_to_fp.mlir

Lines changed: 64 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -62,22 +62,40 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
6262
tt.func @downcast_to_f8(%arg0: tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>,
6363
%arg1: tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>,
6464
%arg2: tensor<8x8xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>) {
65-
// GFX950-COUNT-4: rocdl.cvt.scalef32.pk.bf8.f32
65+
// GFX950: rocdl.cvt.scalef32.pk.bf8.f32 %{{.*}}, %{{.*}}, %{{.*}} -> %{{.*}}[false]
66+
// GFX950: rocdl.cvt.scalef32.pk.bf8.f32 %{{.*}}, %{{.*}}, %{{.*}} -> %{{.*}}[true]
67+
// GFX950: rocdl.cvt.scalef32.pk.bf8.f32 %{{.*}}, %{{.*}}, %{{.*}} -> %{{.*}}[false]
68+
// GFX950: rocdl.cvt.scalef32.pk.bf8.f32 %{{.*}}, %{{.*}}, %{{.*}} -> %{{.*}}[true]
6669
%0 = tt.fp_to_fp %arg0, rounding = rtne : tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
6770

68-
// GFX950-COUNT-4: rocdl.cvt.scalef32.pk.bf8.f16
71+
// GFX950: rocdl.cvt.scalef32.pk.bf8.f16 %{{.*}}, %{{.*}} -> %{{.*}}[false]
72+
// GFX950: rocdl.cvt.scalef32.pk.bf8.f16 %{{.*}}, %{{.*}} -> %{{.*}}[true]
73+
// GFX950: rocdl.cvt.scalef32.pk.bf8.f16 %{{.*}}, %{{.*}} -> %{{.*}}[false]
74+
// GFX950: rocdl.cvt.scalef32.pk.bf8.f16 %{{.*}}, %{{.*}} -> %{{.*}}[true]
6975
%1 = tt.fp_to_fp %arg1, rounding = rtne : tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
7076

71-
// GFX950-COUNT-4: rocdl.cvt.scalef32.pk.bf8.bf16
77+
// GFX950: rocdl.cvt.scalef32.pk.bf8.bf16 %{{.*}}, %{{.*}} -> %{{.*}}[false]
78+
// GFX950: rocdl.cvt.scalef32.pk.bf8.bf16 %{{.*}}, %{{.*}} -> %{{.*}}[true]
79+
// GFX950: rocdl.cvt.scalef32.pk.bf8.bf16 %{{.*}}, %{{.*}} -> %{{.*}}[false]
80+
// GFX950: rocdl.cvt.scalef32.pk.bf8.bf16 %{{.*}}, %{{.*}} -> %{{.*}}[true]
7281
%2 = tt.fp_to_fp %arg2, rounding = rtne : tensor<8x8xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
7382

74-
// GFX950-COUNT-4: rocdl.cvt.scalef32.pk.fp8.f32
83+
// GFX950: rocdl.cvt.scalef32.pk.fp8.f32 %{{.*}}, %{{.*}}, %{{.*}} -> %{{.*}}[false]
84+
// GFX950: rocdl.cvt.scalef32.pk.fp8.f32 %{{.*}}, %{{.*}}, %{{.*}} -> %{{.*}}[true]
85+
// GFX950: rocdl.cvt.scalef32.pk.fp8.f32 %{{.*}}, %{{.*}}, %{{.*}} -> %{{.*}}[false]
86+
// GFX950: rocdl.cvt.scalef32.pk.fp8.f32 %{{.*}}, %{{.*}}, %{{.*}} -> %{{.*}}[true]
7587
%3 = tt.fp_to_fp %arg0, rounding = rtne : tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
7688

77-
// GFX950-COUNT-4: rocdl.cvt.scalef32.pk.fp8.f16
89+
// GFX950: rocdl.cvt.scalef32.pk.fp8.f16 %{{.*}}, %{{.*}} -> %{{.*}}[false]
90+
// GFX950: rocdl.cvt.scalef32.pk.fp8.f16 %{{.*}}, %{{.*}} -> %{{.*}}[true]
91+
// GFX950: rocdl.cvt.scalef32.pk.fp8.f16 %{{.*}}, %{{.*}} -> %{{.*}}[false]
92+
// GFX950: rocdl.cvt.scalef32.pk.fp8.f16 %{{.*}}, %{{.*}} -> %{{.*}}[true]
7893
%4 = tt.fp_to_fp %arg1, rounding = rtne : tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
7994

80-
// GFX950-COUNT-4: rocdl.cvt.scalef32.pk.fp8.bf16
95+
// GFX950: rocdl.cvt.scalef32.pk.fp8.bf16 %{{.*}}, %{{.*}} -> %{{.*}}[false]
96+
// GFX950: rocdl.cvt.scalef32.pk.fp8.bf16 %{{.*}}, %{{.*}} -> %{{.*}}[true]
97+
// GFX950: rocdl.cvt.scalef32.pk.fp8.bf16 %{{.*}}, %{{.*}} -> %{{.*}}[false]
98+
// GFX950: rocdl.cvt.scalef32.pk.fp8.bf16 %{{.*}}, %{{.*}} -> %{{.*}}[true]
8199
%5 = tt.fp_to_fp %arg2, rounding = rtne : tensor<8x8xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
82100
tt.return
83101
}
@@ -89,7 +107,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
89107
#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
90108
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
91109
tt.func @downcast_to_bf8(%arg0: tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>) {
92-
// GFX942-COUNT-4: rocdl.cvt.pk.bf8.f32
110+
// GFX942: rocdl.cvt.pk.bf8.f32 %{{.*}}, %{{.*}} -> %{{.*}}[false]
111+
// GFX942: rocdl.cvt.pk.bf8.f32 %{{.*}}, %{{.*}} -> %{{.*}}[true]
112+
// GFX942: rocdl.cvt.pk.bf8.f32 %{{.*}}, %{{.*}} -> %{{.*}}[false]
113+
// GFX942: rocdl.cvt.pk.bf8.f32 %{{.*}}, %{{.*}} -> %{{.*}}[true]
93114
// GFX950-COUNT-16: llvm.trunc %{{.+}} : i32 to i8
94115
%6 = tt.fp_to_fp %arg0, rounding = rtne : tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf8E5M2FNUZ, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
95116
tt.return
@@ -102,7 +123,10 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
102123
#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
103124
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
104125
tt.func @f32_to_f8(%arg0: tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>) {
105-
// GFX942-COUNT-4: rocdl.cvt.pk.fp8.f32
126+
// GFX942: rocdl.cvt.pk.fp8.f32 %{{.*}}, %{{.*}} -> %{{.*}}[false]
127+
// GFX942: rocdl.cvt.pk.fp8.f32 %{{.*}}, %{{.*}} -> %{{.*}}[true]
128+
// GFX942: rocdl.cvt.pk.fp8.f32 %{{.*}}, %{{.*}} -> %{{.*}}[false]
129+
// GFX942: rocdl.cvt.pk.fp8.f32 %{{.*}}, %{{.*}} -> %{{.*}}[true]
106130
// GFX950-COUNT-16: llvm.trunc %{{.+}} : i32 to i8
107131
%7 = tt.fp_to_fp %arg0, rounding = rtne : tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf8E4M3FNUZ, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
108132
tt.return
@@ -118,28 +142,52 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
118142
%arg1: tensor<8x8xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>,
119143
%arg2: tensor<8x8xf8E5M2FNUZ, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>,
120144
%arg3: tensor<8x8xf8E4M3FNUZ, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>) {
121-
// GFX950-COUNT-4: rocdl.cvt.scalef32.pk.f32.bf8
145+
// GFX950: rocdl.cvt.scalef32.pk.f32.bf8 %[[VR1:.*]][false]
146+
// GFX950: rocdl.cvt.scalef32.pk.f32.bf8 %[[VR1]][true]
147+
// GFX950: rocdl.cvt.scalef32.pk.f32.bf8 %[[VR2:.*]][false]
148+
// GFX950: rocdl.cvt.scalef32.pk.f32.bf8 %[[VR2]][true]
122149
%0 = tt.fp_to_fp %arg0 : tensor<8x8xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
123150

124-
// GFX950-COUNT-4: rocdl.cvt.scalef32.pk.f16.bf8
151+
// GFX950: rocdl.cvt.scalef32.pk.f16.bf8 %[[VR3:.*]][false]
152+
// GFX950: rocdl.cvt.scalef32.pk.f16.bf8 %[[VR3]][true]
153+
// GFX950: rocdl.cvt.scalef32.pk.f16.bf8 %[[VR4:.*]][false]
154+
// GFX950: rocdl.cvt.scalef32.pk.f16.bf8 %[[VR4]][true]
125155
%1 = tt.fp_to_fp %arg0 : tensor<8x8xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
126156

127-
// GFX950-COUNT-4: rocdl.cvt.scalef32.pk.bf16.bf8
157+
// GFX950: rocdl.cvt.scalef32.pk.bf16.bf8 %[[VR5:.*]][false]
158+
// GFX950: rocdl.cvt.scalef32.pk.bf16.bf8 %[[VR5]][true]
159+
// GFX950: rocdl.cvt.scalef32.pk.bf16.bf8 %[[VR6:.*]][false]
160+
// GFX950: rocdl.cvt.scalef32.pk.bf16.bf8 %[[VR6]][true]
128161
%2 = tt.fp_to_fp %arg0 : tensor<8x8xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
129162

130-
// GFX950-COUNT-4: rocdl.cvt.scalef32.pk.f32.fp8
163+
// GFX950: rocdl.cvt.scalef32.pk.f32.fp8 %[[VR7:.*]][false]
164+
// GFX950: rocdl.cvt.scalef32.pk.f32.fp8 %[[VR7]][true]
165+
// GFX950: rocdl.cvt.scalef32.pk.f32.fp8 %[[VR8:.*]][false]
166+
// GFX950: rocdl.cvt.scalef32.pk.f32.fp8 %[[VR8]][true]
131167
%3 = tt.fp_to_fp %arg1 : tensor<8x8xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
132168

133-
// GFX950-COUNT-4: rocdl.cvt.scalef32.pk.f16.fp8
169+
// GFX950: rocdl.cvt.scalef32.pk.f16.fp8 %[[VR9:.*]][false]
170+
// GFX950: rocdl.cvt.scalef32.pk.f16.fp8 %[[VR9]][true]
171+
// GFX950: rocdl.cvt.scalef32.pk.f16.fp8 %[[VR10:.*]][false]
172+
// GFX950: rocdl.cvt.scalef32.pk.f16.fp8 %[[VR10]][true]
134173
%4 = tt.fp_to_fp %arg1 : tensor<8x8xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
135174

136-
// GFX950-COUNT-4: rocdl.cvt.scalef32.pk.bf16.fp8
175+
// GFX950: rocdl.cvt.scalef32.pk.bf16.fp8 %[[VR11:.*]][false]
176+
// GFX950: rocdl.cvt.scalef32.pk.bf16.fp8 %[[VR11]][true]
177+
// GFX950: rocdl.cvt.scalef32.pk.bf16.fp8 %[[VR12:.*]][false]
178+
// GFX950: rocdl.cvt.scalef32.pk.bf16.fp8 %[[VR12]][true]
137179
%5 = tt.fp_to_fp %arg1 : tensor<8x8xf8E4M3FN, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
138180

139-
// GFX942-COUNT-4: rocdl.cvt.pk.f32.bf8
181+
// GFX942: rocdl.cvt.pk.f32.bf8 %[[VR13:.*]][false]
182+
// GFX942: rocdl.cvt.pk.f32.bf8 %[[VR13]][true]
183+
// GFX942: rocdl.cvt.pk.f32.bf8 %[[VR14:.*]][false]
184+
// GFX942: rocdl.cvt.pk.f32.bf8 %[[VR14]][true]
140185
%6 = tt.fp_to_fp %arg2 : tensor<8x8xf8E5M2FNUZ, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
141186

142-
// GFX942-COUNT-4: rocdl.cvt.pk.f32.fp8
187+
// GFX942: rocdl.cvt.pk.f32.fp8 %[[VR15:.*]][false]
188+
// GFX942: rocdl.cvt.pk.f32.fp8 %[[VR15]][true]
189+
// GFX942: rocdl.cvt.pk.f32.fp8 %[[VR16:.*]][false]
190+
// GFX942: rocdl.cvt.pk.f32.fp8 %[[VR16]][true]
143191
%7 = tt.fp_to_fp %arg3 : tensor<8x8xf8E4M3FNUZ, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>> -> tensor<8x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked2}>>
144192
tt.return
145193
}

0 commit comments

Comments
 (0)