Skip to content

Commit c328c5d

Browse files
rampiteckosarev
andauthored
[AMDGPU] Combine to bf16 reciprocal square root. (#154185)
Co-authored-by: Ivan Kosarev <[email protected]> Co-authored-by: Ivan Kosarev <[email protected]>
1 parent b20bbd4 commit c328c5d

File tree

2 files changed

+28
-47
lines changed

2 files changed

+28
-47
lines changed

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15729,7 +15729,7 @@ SDValue SITargetLowering::performFDivCombine(SDNode *N,
1572915729
SelectionDAG &DAG = DCI.DAG;
1573015730
SDLoc SL(N);
1573115731
EVT VT = N->getValueType(0);
15732-
if (VT != MVT::f16 || !Subtarget->has16BitInsts())
15732+
if ((VT != MVT::f16 && VT != MVT::bf16) || !Subtarget->has16BitInsts())
1573315733
return SDValue();
1573415734

1573515735
SDValue LHS = N->getOperand(0);

llvm/test/CodeGen/AMDGPU/fdiv.bf16.ll

Lines changed: 27 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -82,78 +82,69 @@ define bfloat @v_rcp_bf16_neg(bfloat %x) {
8282
ret bfloat %fdiv
8383
}
8484

85-
; TODO: Support lowering to v_rsq_bf16.
8685
define bfloat @v_rsq_bf16(bfloat %x) {
8786
; GFX1250-TRUE16-LABEL: v_rsq_bf16:
8887
; GFX1250-TRUE16: ; %bb.0:
8988
; GFX1250-TRUE16-NEXT: s_wait_loadcnt_dscnt 0x0
9089
; GFX1250-TRUE16-NEXT: s_wait_kmcnt 0x0
91-
; GFX1250-TRUE16-NEXT: v_sqrt_bf16_e32 v0.l, v0.l
92-
; GFX1250-TRUE16-NEXT: s_delay_alu instid0(TRANS32_DEP_1)
93-
; GFX1250-TRUE16-NEXT: v_rcp_bf16_e32 v0.l, v0.l
90+
; GFX1250-TRUE16-NEXT: v_rsq_bf16_e32 v0.l, v0.l
9491
; GFX1250-TRUE16-NEXT: s_set_pc_i64 s[30:31]
9592
;
9693
; GFX1250-FAKE16-LABEL: v_rsq_bf16:
9794
; GFX1250-FAKE16: ; %bb.0:
9895
; GFX1250-FAKE16-NEXT: s_wait_loadcnt_dscnt 0x0
9996
; GFX1250-FAKE16-NEXT: s_wait_kmcnt 0x0
100-
; GFX1250-FAKE16-NEXT: v_sqrt_bf16_e32 v0, v0
101-
; GFX1250-FAKE16-NEXT: s_delay_alu instid0(TRANS32_DEP_1)
102-
; GFX1250-FAKE16-NEXT: v_rcp_bf16_e32 v0, v0
97+
; GFX1250-FAKE16-NEXT: v_rsq_bf16_e32 v0, v0
10398
; GFX1250-FAKE16-NEXT: s_set_pc_i64 s[30:31]
10499
%sqrt = call contract bfloat @llvm.sqrt.bf16(bfloat %x)
105100
%fdiv = fdiv contract bfloat 1.0, %sqrt
106101
ret bfloat %fdiv
107102
}
108103

109-
; TODO: Support lowering to v_rsq_bf16.
110104
define bfloat @v_rsq_bf16_neg(bfloat %x) {
111105
; GFX1250-TRUE16-LABEL: v_rsq_bf16_neg:
112106
; GFX1250-TRUE16: ; %bb.0:
113107
; GFX1250-TRUE16-NEXT: s_wait_loadcnt_dscnt 0x0
114108
; GFX1250-TRUE16-NEXT: s_wait_kmcnt 0x0
115-
; GFX1250-TRUE16-NEXT: v_sqrt_bf16_e32 v0.l, v0.l
109+
; GFX1250-TRUE16-NEXT: v_rsq_bf16_e32 v0.l, v0.l
110+
; GFX1250-TRUE16-NEXT: v_nop
116111
; GFX1250-TRUE16-NEXT: s_delay_alu instid0(TRANS32_DEP_1)
117-
; GFX1250-TRUE16-NEXT: v_rcp_bf16_e64 v0.l, -v0.l
112+
; GFX1250-TRUE16-NEXT: v_xor_b16 v0.l, 0x8000, v0.l
118113
; GFX1250-TRUE16-NEXT: s_set_pc_i64 s[30:31]
119114
;
120115
; GFX1250-FAKE16-LABEL: v_rsq_bf16_neg:
121116
; GFX1250-FAKE16: ; %bb.0:
122117
; GFX1250-FAKE16-NEXT: s_wait_loadcnt_dscnt 0x0
123118
; GFX1250-FAKE16-NEXT: s_wait_kmcnt 0x0
124-
; GFX1250-FAKE16-NEXT: v_sqrt_bf16_e32 v0, v0
119+
; GFX1250-FAKE16-NEXT: v_rsq_bf16_e32 v0, v0
120+
; GFX1250-FAKE16-NEXT: v_nop
125121
; GFX1250-FAKE16-NEXT: s_delay_alu instid0(TRANS32_DEP_1)
126-
; GFX1250-FAKE16-NEXT: v_rcp_bf16_e64 v0, -v0
122+
; GFX1250-FAKE16-NEXT: v_xor_b32_e32 v0, 0x8000, v0
127123
; GFX1250-FAKE16-NEXT: s_set_pc_i64 s[30:31]
128124
%sqrt = call contract bfloat @llvm.sqrt.bf16(bfloat %x)
129125
%fdiv = fdiv contract bfloat -1.0, %sqrt
130126
ret bfloat %fdiv
131127
}
132128

133-
; TODO: Support lowering to v_rsq_bf16.
134129
define <2 x bfloat> @v_rsq_bf16_multi_use(bfloat %x) {
135130
; GFX1250-TRUE16-LABEL: v_rsq_bf16_multi_use:
136131
; GFX1250-TRUE16: ; %bb.0:
137132
; GFX1250-TRUE16-NEXT: s_wait_loadcnt_dscnt 0x0
138133
; GFX1250-TRUE16-NEXT: s_wait_kmcnt 0x0
139134
; GFX1250-TRUE16-NEXT: v_mov_b16_e32 v1.l, v0.l
140-
; GFX1250-TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(TRANS32_DEP_1)
141-
; GFX1250-TRUE16-NEXT: v_sqrt_bf16_e32 v1.l, v1.l
142-
; GFX1250-TRUE16-NEXT: v_rcp_bf16_e32 v1.h, v1.l
135+
; GFX1250-TRUE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(TRANS32_DEP_1)
136+
; GFX1250-TRUE16-NEXT: v_rsq_bf16_e32 v1.h, v1.l
143137
; GFX1250-TRUE16-NEXT: v_nop
144-
; GFX1250-TRUE16-NEXT: v_mov_b16_e32 v1.l, v0.l
145-
; GFX1250-TRUE16-NEXT: s_delay_alu instid0(TRANS32_DEP_1) | instid1(VALU_DEP_1)
146138
; GFX1250-TRUE16-NEXT: v_mov_b32_e32 v0, v1
147139
; GFX1250-TRUE16-NEXT: s_set_pc_i64 s[30:31]
148140
;
149141
; GFX1250-FAKE16-LABEL: v_rsq_bf16_multi_use:
150142
; GFX1250-FAKE16: ; %bb.0:
151143
; GFX1250-FAKE16-NEXT: s_wait_loadcnt_dscnt 0x0
152144
; GFX1250-FAKE16-NEXT: s_wait_kmcnt 0x0
153-
; GFX1250-FAKE16-NEXT: v_sqrt_bf16_e32 v1, v0
154-
; GFX1250-FAKE16-NEXT: s_delay_alu instid0(TRANS32_DEP_1) | instskip(SKIP_1) | instid1(TRANS32_DEP_1)
155-
; GFX1250-FAKE16-NEXT: v_rcp_bf16_e32 v1, v1
145+
; GFX1250-FAKE16-NEXT: v_rsq_bf16_e32 v1, v0
156146
; GFX1250-FAKE16-NEXT: v_nop
147+
; GFX1250-FAKE16-NEXT: s_delay_alu instid0(TRANS32_DEP_1)
157148
; GFX1250-FAKE16-NEXT: v_perm_b32 v0, v1, v0, 0x5040100
158149
; GFX1250-FAKE16-NEXT: s_set_pc_i64 s[30:31]
159150
%sqrt = call contract bfloat @llvm.sqrt.bf16(bfloat %x)
@@ -163,7 +154,6 @@ define <2 x bfloat> @v_rsq_bf16_multi_use(bfloat %x) {
163154
ret <2 x bfloat> %r2
164155
}
165156

166-
; TODO: Support lowering to v_rsq_bf16.
167157
define bfloat @v_rsq_bf16_missing_contract0(bfloat %x) {
168158
; GFX1250-TRUE16-LABEL: v_rsq_bf16_missing_contract0:
169159
; GFX1250-TRUE16: ; %bb.0:
@@ -187,7 +177,6 @@ define bfloat @v_rsq_bf16_missing_contract0(bfloat %x) {
187177
ret bfloat %fdiv
188178
}
189179

190-
; TODO: Support lowering to v_rsq_bf16.
191180
define bfloat @v_rsq_bf16_missing_contract1(bfloat %x) {
192181
; GFX1250-TRUE16-LABEL: v_rsq_bf16_missing_contract1:
193182
; GFX1250-TRUE16: ; %bb.0:
@@ -211,7 +200,6 @@ define bfloat @v_rsq_bf16_missing_contract1(bfloat %x) {
211200
ret bfloat %fdiv
212201
}
213202

214-
; TODO: Support lowering to v_rsq_bf16.
215203
define bfloat @v_neg_rsq_bf16_missing_contract1(bfloat %x) {
216204
; GFX1250-TRUE16-LABEL: v_neg_rsq_bf16_missing_contract1:
217205
; GFX1250-TRUE16: ; %bb.0:
@@ -240,24 +228,18 @@ define <2 x bfloat> @v_rsq_v2bf16(<2 x bfloat> %a) {
240228
; GFX1250-TRUE16: ; %bb.0:
241229
; GFX1250-TRUE16-NEXT: s_wait_loadcnt_dscnt 0x0
242230
; GFX1250-TRUE16-NEXT: s_wait_kmcnt 0x0
243-
; GFX1250-TRUE16-NEXT: v_sqrt_bf16_e32 v0.h, v0.h
244-
; GFX1250-TRUE16-NEXT: v_sqrt_bf16_e32 v0.l, v0.l
245-
; GFX1250-TRUE16-NEXT: s_delay_alu instid0(TRANS32_DEP_2) | instskip(NEXT) | instid1(TRANS32_DEP_2)
246-
; GFX1250-TRUE16-NEXT: v_rcp_bf16_e32 v0.h, v0.h
247-
; GFX1250-TRUE16-NEXT: v_rcp_bf16_e32 v0.l, v0.l
231+
; GFX1250-TRUE16-NEXT: v_rsq_bf16_e32 v0.h, v0.h
232+
; GFX1250-TRUE16-NEXT: v_rsq_bf16_e32 v0.l, v0.l
248233
; GFX1250-TRUE16-NEXT: s_set_pc_i64 s[30:31]
249234
;
250235
; GFX1250-FAKE16-LABEL: v_rsq_v2bf16:
251236
; GFX1250-FAKE16: ; %bb.0:
252237
; GFX1250-FAKE16-NEXT: s_wait_loadcnt_dscnt 0x0
253238
; GFX1250-FAKE16-NEXT: s_wait_kmcnt 0x0
254239
; GFX1250-FAKE16-NEXT: v_lshrrev_b32_e32 v1, 16, v0
255-
; GFX1250-FAKE16-NEXT: v_sqrt_bf16_e32 v0, v0
256-
; GFX1250-FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(TRANS32_DEP_2)
257-
; GFX1250-FAKE16-NEXT: v_sqrt_bf16_e32 v1, v1
258-
; GFX1250-FAKE16-NEXT: v_rcp_bf16_e32 v0, v0
259-
; GFX1250-FAKE16-NEXT: s_delay_alu instid0(TRANS32_DEP_2) | instskip(SKIP_1) | instid1(TRANS32_DEP_1)
260-
; GFX1250-FAKE16-NEXT: v_rcp_bf16_e32 v1, v1
240+
; GFX1250-FAKE16-NEXT: v_rsq_bf16_e32 v0, v0
241+
; GFX1250-FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(SKIP_1) | instid1(TRANS32_DEP_1)
242+
; GFX1250-FAKE16-NEXT: v_rsq_bf16_e32 v1, v1
261243
; GFX1250-FAKE16-NEXT: v_nop
262244
; GFX1250-FAKE16-NEXT: v_perm_b32 v0, v1, v0, 0x5040100
263245
; GFX1250-FAKE16-NEXT: s_set_pc_i64 s[30:31]
@@ -271,25 +253,24 @@ define <2 x bfloat> @v_neg_rsq_v2bf16(<2 x bfloat> %a) {
271253
; GFX1250-TRUE16: ; %bb.0:
272254
; GFX1250-TRUE16-NEXT: s_wait_loadcnt_dscnt 0x0
273255
; GFX1250-TRUE16-NEXT: s_wait_kmcnt 0x0
274-
; GFX1250-TRUE16-NEXT: v_sqrt_bf16_e32 v0.h, v0.h
275-
; GFX1250-TRUE16-NEXT: v_sqrt_bf16_e32 v0.l, v0.l
276-
; GFX1250-TRUE16-NEXT: s_delay_alu instid0(TRANS32_DEP_2) | instskip(NEXT) | instid1(TRANS32_DEP_2)
277-
; GFX1250-TRUE16-NEXT: v_rcp_bf16_e64 v0.h, -v0.h
278-
; GFX1250-TRUE16-NEXT: v_rcp_bf16_e64 v0.l, -v0.l
256+
; GFX1250-TRUE16-NEXT: v_rsq_bf16_e32 v0.h, v0.h
257+
; GFX1250-TRUE16-NEXT: v_rsq_bf16_e32 v0.l, v0.l
258+
; GFX1250-TRUE16-NEXT: s_delay_alu instid0(TRANS32_DEP_2) | instskip(NEXT) | instid1(TRANS32_DEP_1)
259+
; GFX1250-TRUE16-NEXT: v_xor_b16 v0.h, 0x8000, v0.h
260+
; GFX1250-TRUE16-NEXT: v_xor_b16 v0.l, 0x8000, v0.l
279261
; GFX1250-TRUE16-NEXT: s_set_pc_i64 s[30:31]
280262
;
281263
; GFX1250-FAKE16-LABEL: v_neg_rsq_v2bf16:
282264
; GFX1250-FAKE16: ; %bb.0:
283265
; GFX1250-FAKE16-NEXT: s_wait_loadcnt_dscnt 0x0
284266
; GFX1250-FAKE16-NEXT: s_wait_kmcnt 0x0
285267
; GFX1250-FAKE16-NEXT: v_lshrrev_b32_e32 v1, 16, v0
286-
; GFX1250-FAKE16-NEXT: v_sqrt_bf16_e32 v0, v0
268+
; GFX1250-FAKE16-NEXT: v_rsq_bf16_e32 v0, v0
287269
; GFX1250-FAKE16-NEXT: s_delay_alu instid0(VALU_DEP_1) | instskip(NEXT) | instid1(TRANS32_DEP_2)
288-
; GFX1250-FAKE16-NEXT: v_sqrt_bf16_e32 v1, v1
289-
; GFX1250-FAKE16-NEXT: v_rcp_bf16_e64 v0, -v0
290-
; GFX1250-FAKE16-NEXT: s_delay_alu instid0(TRANS32_DEP_2) | instskip(SKIP_1) | instid1(TRANS32_DEP_1)
291-
; GFX1250-FAKE16-NEXT: v_rcp_bf16_e64 v1, -v1
292-
; GFX1250-FAKE16-NEXT: v_nop
270+
; GFX1250-FAKE16-NEXT: v_rsq_bf16_e32 v1, v1
271+
; GFX1250-FAKE16-NEXT: v_xor_b32_e32 v0, 0x8000, v0
272+
; GFX1250-FAKE16-NEXT: s_delay_alu instid0(TRANS32_DEP_1) | instskip(NEXT) | instid1(VALU_DEP_1)
273+
; GFX1250-FAKE16-NEXT: v_xor_b32_e32 v1, 0x8000, v1
293274
; GFX1250-FAKE16-NEXT: v_perm_b32 v0, v1, v0, 0x5040100
294275
; GFX1250-FAKE16-NEXT: s_set_pc_i64 s[30:31]
295276
%sqrt = call contract <2 x bfloat> @llvm.sqrt.v2bf16(<2 x bfloat> %a)

0 commit comments

Comments
 (0)