@@ -19,70 +19,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
19
19
#mfma = #ttg.amd_mfma <{version = 3 , warpsPerCTA = [4 , 1 ], instrShape = [32 , 32 ], isTransposed = true }>
20
20
#dotop0 = #ttg.dot_op <{opIdx = 0 , parent = #mfma , kWidth =8 }>
21
21
22
- module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , " ttg.threads-per-warp" = 64 : i32 } {
23
- // GFX942-LABEL: mfma_dot_cvt_f8_mfma32
24
- tt.func public @mfma_dot_cvt_f8_mfma32 (%arg0: tensor <128 x32 xf8 E4 M3 FNUZ, #mfma >) {
25
- // GFX942-NOT: store
26
- // GFX942-NOT: load
27
-
28
- // GFX942: [[val3:%.*]] = llvm.extractvalue %arg0[3]
29
- // GFX942: [[val7:%.*]] = llvm.extractvalue %arg0[7]
30
-
31
- // GFX942-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32)
32
- // GFX942-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32)
33
-
34
- // GFX942: [[threadId:%.*]] = rocdl.workitem.id.x
35
- // GFX942: [[c255:%.*]] = llvm.mlir.constant(255 : i32)
36
- // GFX942: [[RTID:%.*]] = llvm.and [[threadId]], [[c255]]
37
- // GFX942: [[laneId:%.*]] = llvm.urem [[RTID]], [[c64]]
38
- // GFX942: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]]
39
-
40
- // GFX942: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]]
41
- // GFX942: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
42
-
43
- // GFX942: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8>
44
- // GFX942: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8>
45
-
46
- // GFX942: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
47
- // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
48
- // GFX942: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
49
- // GFX942: [[bShflVec0:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
50
- // GFX942: [[shflVec0:%.*]] = llvm.bitcast [[bShflVec0]]
51
-
52
- // GFX942: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
53
- // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
54
- // GFX942: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
55
- // GFX942: [[bShflVec1:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
56
- // GFX942: [[shflVec1:%.*]] = llvm.bitcast [[bShflVec1]]
57
-
58
- // Input (8 values): (vec0, vec1)
59
- // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64):
60
- // resVec0 resVec1
61
- // lanes 0-31: (vec0 , vec0 >> 32) (mask0=1)
62
- // lanes 32-63: (vec1 >> 32, vec1 ) (mask0=0)
63
-
64
- // GFX942: [[resVec0:%.*]] = llvm.select [[mask0]], [[vec0]], [[shflVec1]]
65
- // GFX942: [[resVec1:%.*]] = llvm.select [[mask0]], [[shflVec0]], [[vec1]]
66
-
67
- // GFX942: [[c3:%.*]] = llvm.mlir.constant(3 : i32)
68
- // GFX942: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8>
69
- // GFX942: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32
70
- // GFX942: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8>
71
-
72
- // GFX942: llvm.insertvalue [[resVal3]], {{.*}}[3]
73
- // GFX942: llvm.insertvalue [[resVal7]], {{.*}}[7]
74
-
75
- // GFX942: llvm.return
76
- %0 = ttg.convert_layout %arg0 : tensor <128 x32 xf8 E4 M3 FNUZ, #mfma > -> tensor <128 x32 xf8 E4 M3 FNUZ, #dotop0 >
77
- tt.return
78
- }
79
- }
80
-
81
- // -----
82
-
83
- #mfma = #ttg.amd_mfma <{version = 3 , warpsPerCTA = [4 , 1 ], instrShape = [32 , 32 ], isTransposed = true }>
84
- #dotop0 = #ttg.dot_op <{opIdx = 0 , parent = #mfma , kWidth =8 }>
85
-
86
22
module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , " ttg.threads-per-warp" = 64 : i32 } {
87
23
// GFX942-LABEL: mfma_dot_cvt_bf8_mfma32
88
24
tt.func public @mfma_dot_cvt_bf8_mfma32 (%arg0: tensor <128 x32 xf8 E5 M2 , #mfma >) {
@@ -100,100 +36,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
100
36
#mfma = #ttg.amd_mfma <{version = 3 , warpsPerCTA = [4 , 1 ], instrShape = [16 , 16 ], isTransposed = true }>
101
37
#dotop0 = #ttg.dot_op <{opIdx = 0 , parent = #mfma , kWidth =8 }>
102
38
103
- module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , " ttg.threads-per-warp" = 64 : i32 } {
104
- // GFX942-LABEL: mfma_dot_cvt_f8_mfma16
105
- tt.func public @mfma_dot_cvt_f8_mfma16 (%arg0: tensor <128 x32 xf8 E4 M3 FNUZ, #mfma >) {
106
- // GFX942-NOT: store
107
- // GFX942-NOT: load
108
-
109
- // GFX942: [[val3:%.*]] = llvm.extractvalue %arg0[3]
110
- // GFX942: [[val7:%.*]] = llvm.extractvalue %arg0[7]
111
-
112
- // GFX942-DAG: [[c16:%.*]] = llvm.mlir.constant(16 : i32)
113
- // GFX942-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32)
114
- // GFX942-DAG: [[c48:%.*]] = llvm.mlir.constant(48 : i32)
115
- // GFX942-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32)
116
-
117
- // GFX942: [[threadId:%.*]] = rocdl.workitem.id.x
118
- // GFX942: [[c255:%.*]] = llvm.mlir.constant(255 : i32)
119
- // GFX942: [[RTID:%.*]] = llvm.and [[threadId]], [[c255]]
120
- // GFX942: [[laneId:%.*]] = llvm.urem [[RTID]], [[c64]]
121
- // GFX942: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]]
122
-
123
- // GFX942: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]]
124
- // GFX942: [[mask1:%.*]] = llvm.icmp "slt" [[laneIdRem]], [[c16]]
125
-
126
- // GFX942: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c16]]
127
- // GFX942: [[addr16:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
128
-
129
- // GFX942: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]]
130
- // GFX942: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
131
-
132
- // GFX942: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c48]]
133
- // GFX942: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[c64]]
134
-
135
- // GFX942: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8>
136
- // GFX942: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8>
137
-
138
- // GFX942: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
139
- // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
140
- // GFX942: [[addr:%.*]] = llvm.shl [[addr16]], [[c2]]
141
- // GFX942: [[bShflVec0_16:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
142
- // GFX942: [[shflVec0_16:%.*]] = llvm.bitcast [[bShflVec0_16]]
143
-
144
- // GFX942: [[bvec0:%.*]] = llvm.bitcast [[vec0]]
145
- // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
146
- // GFX942: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
147
- // GFX942: [[bShflVec0_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]]
148
- // GFX942: [[shflVec0_32:%.*]] = llvm.bitcast [[bShflVec0_32]]
149
-
150
- // GFX942: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
151
- // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
152
- // GFX942: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]]
153
- // GFX942: [[bShflVec1_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
154
- // GFX942: [[shflVec1_32:%.*]] = llvm.bitcast [[bShflVec1_32]]
155
-
156
- // GFX942: [[bvec1:%.*]] = llvm.bitcast [[vec1]]
157
- // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
158
- // GFX942: [[addr:%.*]] = llvm.shl [[addr48]], [[c2]]
159
- // GFX942: [[bShflVec1_48:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]]
160
- // GFX942: [[shflVec1_48:%.*]] = llvm.bitcast [[bShflVec1_48]]
161
-
162
- // Input (8 values): (vec0, vec1)
163
- // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64):
164
- // resVec0 resVec1
165
- // lanes 0-15: (vec0 , vec0 >> 16) (mask0=1, mask1=1)
166
- // lanes 16-31: (vec0 >> 16, vec0 >> 32) (mask0=1, mask1=0)
167
- // lanes 32-47: (vec1 >> 32, vec1 >> 48) (mask0=0, mask1=1)
168
- // lanes 48-63: (vec1 >> 48, vec1 ) (mask0=0, mask1=0)
169
-
170
- // GFX942-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[vec0]], [[shflVec0_16]] : i1, vector<4xi8>
171
- // GFX942-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_32]], [[shflVec1_48]] : i1, vector<4xi8>
172
- // GFX942: [[resVec0:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8>
173
-
174
- // GFX942-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[shflVec0_16]], [[shflVec0_32]] : i1, vector<4xi8>
175
- // GFX942-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_48]], [[vec1]] : i1, vector<4xi8>
176
- // GFX942: [[resVec1:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8>
177
-
178
- // GFX942: [[c3:%.*]] = llvm.mlir.constant(3 : i32)
179
- // GFX942: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8>
180
- // GFX942: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32
181
- // GFX942: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8>
182
-
183
- // GFX942: llvm.insertvalue [[resVal3]], {{.*}}[3]
184
- // GFX942: llvm.insertvalue [[resVal7]], {{.*}}[7]
185
-
186
- // GFX942: llvm.return
187
- %0 = ttg.convert_layout %arg0 : tensor <128 x32 xf8 E4 M3 FNUZ, #mfma > -> tensor <128 x32 xf8 E4 M3 FNUZ, #dotop0 >
188
- tt.return
189
- }
190
- }
191
-
192
- // -----
193
-
194
- #mfma = #ttg.amd_mfma <{version = 3 , warpsPerCTA = [4 , 1 ], instrShape = [16 , 16 ], isTransposed = true }>
195
- #dotop0 = #ttg.dot_op <{opIdx = 0 , parent = #mfma , kWidth =8 }>
196
-
197
39
module attributes {" ttg.num-ctas" = 1 : i32 , " ttg.num-warps" = 4 : i32 , " ttg.threads-per-warp" = 64 : i32 } {
198
40
// GFX942-LABEL: mfma_dot_cvt_bf8_mfma16
199
41
tt.func public @mfma_dot_cvt_bf8_mfma16 (%arg0: tensor <128 x32 xf8 E5 M2 , #mfma >) {
0 commit comments