Skip to content

Commit 323eca9

Browse files
ViacheslavRb authored and igcbot committed
Support NaN in Bfloat MinMax resolution
Fixing BfloatFuncsResolution pass to support NaN in MinMax resolution.
1 parent c3e6c9d commit 323eca9

File tree

3 files changed

+149
-61
lines changed

3 files changed

+149
-61
lines changed

IGC/Compiler/Optimizer/OpenCLPasses/BfloatFuncs/BfloatFuncsResolution.cpp

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -231,9 +231,16 @@ void BfloatFuncsResolution::handleMinMax(CallInst &CI,
231231
auto Op0Bf = bitcastToBfloat(Op0);
232232
auto Op1Bf = bitcastToBfloat(Op1);
233233

234+
// According to OpenCL C spec:
235+
// If one argument is a NaN, fmax() or fmin() returns the other argument.
236+
// If both arguments are NaNs, fmax() or fmin() returns a NaN.
234237
auto CompareInst = m_builder->CreateFCmp(Pred, Op0Bf, Op1Bf);
235238
auto SelectInst = m_builder->CreateSelect(CompareInst, Op0Bf, Op1Bf);
236-
auto Res = m_builder->CreateBitCast(SelectInst, CI.getType());
239+
auto IsNaNOp0 = m_builder->CreateFCmp(CmpInst::Predicate::FCMP_UNO, Op0Bf, Op0Bf);
240+
auto OtherVal = m_builder->CreateSelect(IsNaNOp0, Op1Bf, Op0Bf);
241+
auto CompareNaNInst = m_builder->CreateFCmp(CmpInst::Predicate::FCMP_ORD, Op0Bf, Op1Bf);
242+
auto SelectInst3 = m_builder->CreateSelect(CompareNaNInst, SelectInst, OtherVal);
243+
auto Res = m_builder->CreateBitCast(SelectInst3, CI.getType());
237244

238245
CI.replaceAllUsesWith(Res);
239246
m_instructionsToRemove.push_back(&CI);

IGC/Compiler/tests/BfloatFuncsResolution/min_max-typed-pointers.ll

Lines changed: 70 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -50,45 +50,65 @@ define spir_kernel void @test_min(i16 addrspace(1)* %out1, i16 zeroext %v1_1, i1
5050
entry:
5151
; CHECK: %[[SRC0BF:.*]] = bitcast i16 %v1_1 to bfloat
5252
; CHECK: %[[SRC1BF:.*]] = bitcast i16 %v2_1 to bfloat
53-
; CHECK: %[[COND:.*]] = fcmp olt bfloat %[[SRC0BF]], %[[SRC1BF]]
54-
; CHECK: %[[SELECTRES:.*]] = select i1 %[[COND]], bfloat %[[SRC0BF]], bfloat %[[SRC1BF]]
55-
; CHECK: %{{.*}} = bitcast bfloat %[[SELECTRES]] to i16
53+
; CHECK: %[[COND1:.*]] = fcmp olt bfloat %[[SRC0BF]], %[[SRC1BF]]
54+
; CHECK: %[[SELECT1RES:.*]] = select i1 %[[COND1]], bfloat %[[SRC0BF]], bfloat %[[SRC1BF]]
55+
; CHECK: %[[COND2:.*]] = fcmp uno bfloat %[[SRC0BF]], %[[SRC0BF]]
56+
; CHECK: %[[SELECT2RES:.*]] = select i1 %[[COND2]], bfloat %[[SRC1BF]], bfloat %[[SRC0BF]]
57+
; CHECK: %[[COND3:.*]] = fcmp ord bfloat %[[SRC0BF]], %[[SRC1BF]]
58+
; CHECK: %[[SELECT3RES:.*]] = select i1 %[[COND3]], bfloat %[[SELECT1RES]], bfloat %[[SELECT2RES]]
59+
; CHECK: %{{.*}} = bitcast bfloat %[[SELECT3RES]] to i16
5660
%call = call spir_func zeroext i16 @_Z18__builtin_bf16_mintt(i16 zeroext %v1_1, i16 zeroext %v2_1) #2
5761
%arrayidx = getelementptr inbounds i16, i16 addrspace(1)* %out1, i64 0
5862
store i16 %call, i16 addrspace(1)* %arrayidx, align 2
5963

6064
; CHECK: %[[SRC0BF:.*]] = bitcast <2 x i16> %v1_2 to <2 x bfloat>
6165
; CHECK: %[[SRC1BF:.*]] = bitcast <2 x i16> %v2_2 to <2 x bfloat>
62-
; CHECK: %[[COND:.*]] = fcmp olt <2 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
63-
; CHECK: %[[SELECTRES:.*]] = select <2 x i1> %[[COND]], <2 x bfloat> %[[SRC0BF]], <2 x bfloat> %[[SRC1BF]]
64-
; CHECK: %{{.*}} = bitcast <2 x bfloat> %[[SELECTRES]] to <2 x i16>
66+
; CHECK: %[[COND1:.*]] = fcmp olt <2 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
67+
; CHECK: %[[SELECT1RES:.*]] = select <2 x i1> %[[COND1]], <2 x bfloat> %[[SRC0BF]], <2 x bfloat> %[[SRC1BF]]
68+
; CHECK: %[[COND2:.*]] = fcmp uno <2 x bfloat> %[[SRC0BF]], %[[SRC0BF]]
69+
; CHECK: %[[SELECT2RES:.*]] = select <2 x i1> %[[COND2]], <2 x bfloat> %[[SRC1BF]], <2 x bfloat> %[[SRC0BF]]
70+
; CHECK: %[[COND3:.*]] = fcmp ord <2 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
71+
; CHECK: %[[SELECT3RES:.*]] = select <2 x i1> %[[COND3]], <2 x bfloat> %[[SELECT1RES]], <2 x bfloat> %[[SELECT2RES]]
72+
; CHECK: %{{.*}} = bitcast <2 x bfloat> %[[SELECT3RES]] to <2 x i16>
6573
%call1 = call spir_func <2 x i16> @_Z18__builtin_bf16_minDv2_tS_(<2 x i16> %v1_2, <2 x i16> %v2_2) #2
6674
%arrayidx2 = getelementptr inbounds <2 x i16>, <2 x i16> addrspace(1)* %out2, i64 1
6775
store <2 x i16> %call1, <2 x i16> addrspace(1)* %arrayidx2, align 4
6876

6977
; CHECK: %[[SRC0BF:.*]] = bitcast <4 x i16> %v1_4 to <4 x bfloat>
7078
; CHECK: %[[SRC1BF:.*]] = bitcast <4 x i16> %v2_4 to <4 x bfloat>
71-
; CHECK: %[[COND:.*]] = fcmp olt <4 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
72-
; CHECK: %[[SELECTRES:.*]] = select <4 x i1> %[[COND]], <4 x bfloat> %[[SRC0BF]], <4 x bfloat> %[[SRC1BF]]
73-
; CHECK: %{{.*}} = bitcast <4 x bfloat> %[[SELECTRES]] to <4 x i16>
79+
; CHECK: %[[COND1:.*]] = fcmp olt <4 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
80+
; CHECK: %[[SELECT1RES:.*]] = select <4 x i1> %[[COND1]], <4 x bfloat> %[[SRC0BF]], <4 x bfloat> %[[SRC1BF]]
81+
; CHECK: %[[COND2:.*]] = fcmp uno <4 x bfloat> %[[SRC0BF]], %[[SRC0BF]]
82+
; CHECK: %[[SELECT2RES:.*]] = select <4 x i1> %[[COND2]], <4 x bfloat> %[[SRC1BF]], <4 x bfloat> %[[SRC0BF]]
83+
; CHECK: %[[COND3:.*]] = fcmp ord <4 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
84+
; CHECK: %[[SELECT3RES:.*]] = select <4 x i1> %[[COND3]], <4 x bfloat> %[[SELECT1RES]], <4 x bfloat> %[[SELECT2RES]]
85+
; CHECK: %{{.*}} = bitcast <4 x bfloat> %[[SELECT3RES]] to <4 x i16>
7486
%call3 = call spir_func <4 x i16> @_Z18__builtin_bf16_minDv4_tS_(<4 x i16> %v1_4, <4 x i16> %v2_4) #2
7587
%arrayidx4 = getelementptr inbounds <4 x i16>, <4 x i16> addrspace(1)* %out4, i64 2
7688
store <4 x i16> %call3, <4 x i16> addrspace(1)* %arrayidx4, align 8
7789

7890
; CHECK: %[[SRC0BF:.*]] = bitcast <8 x i16> %v1_8 to <8 x bfloat>
7991
; CHECK: %[[SRC1BF:.*]] = bitcast <8 x i16> %v2_8 to <8 x bfloat>
80-
; CHECK: %[[COND:.*]] = fcmp olt <8 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
81-
; CHECK: %[[SELECTRES:.*]] = select <8 x i1> %[[COND]], <8 x bfloat> %[[SRC0BF]], <8 x bfloat> %[[SRC1BF]]
82-
; CHECK: %{{.*}} = bitcast <8 x bfloat> %[[SELECTRES]] to <8 x i16>
92+
; CHECK: %[[COND1:.*]] = fcmp olt <8 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
93+
; CHECK: %[[SELECT1RES:.*]] = select <8 x i1> %[[COND1]], <8 x bfloat> %[[SRC0BF]], <8 x bfloat> %[[SRC1BF]]
94+
; CHECK: %[[COND2:.*]] = fcmp uno <8 x bfloat> %[[SRC0BF]], %[[SRC0BF]]
95+
; CHECK: %[[SELECT2RES:.*]] = select <8 x i1> %[[COND2]], <8 x bfloat> %[[SRC1BF]], <8 x bfloat> %[[SRC0BF]]
96+
; CHECK: %[[COND3:.*]] = fcmp ord <8 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
97+
; CHECK: %[[SELECT3RES:.*]] = select <8 x i1> %[[COND3]], <8 x bfloat> %[[SELECT1RES]], <8 x bfloat> %[[SELECT2RES]]
98+
; CHECK: %{{.*}} = bitcast <8 x bfloat> %[[SELECT3RES]] to <8 x i16>
8399
%call5 = call spir_func <8 x i16> @_Z18__builtin_bf16_minDv8_tS_(<8 x i16> %v1_8, <8 x i16> %v2_8) #2
84100
%arrayidx6 = getelementptr inbounds <8 x i16>, <8 x i16> addrspace(1)* %out8, i64 3
85101
store <8 x i16> %call5, <8 x i16> addrspace(1)* %arrayidx6, align 16
86102

87103
; CHECK: %[[SRC0BF:.*]] = bitcast <16 x i16> %v1_16 to <16 x bfloat>
88104
; CHECK: %[[SRC1BF:.*]] = bitcast <16 x i16> %v2_16 to <16 x bfloat>
89-
; CHECK: %[[COND:.*]] = fcmp olt <16 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
90-
; CHECK: %[[SELECTRES:.*]] = select <16 x i1> %[[COND]], <16 x bfloat> %[[SRC0BF]], <16 x bfloat> %[[SRC1BF]]
91-
; CHECK: %{{.*}} = bitcast <16 x bfloat> %[[SELECTRES]] to <16 x i16>
105+
; CHECK: %[[COND1:.*]] = fcmp olt <16 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
106+
; CHECK: %[[SELECT1RES:.*]] = select <16 x i1> %[[COND1]], <16 x bfloat> %[[SRC0BF]], <16 x bfloat> %[[SRC1BF]]
107+
; CHECK: %[[COND2:.*]] = fcmp uno <16 x bfloat> %[[SRC0BF]], %[[SRC0BF]]
108+
; CHECK: %[[SELECT2RES:.*]] = select <16 x i1> %[[COND2]], <16 x bfloat> %[[SRC1BF]], <16 x bfloat> %[[SRC0BF]]
109+
; CHECK: %[[COND3:.*]] = fcmp ord <16 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
110+
; CHECK: %[[SELECT3RES:.*]] = select <16 x i1> %[[COND3]], <16 x bfloat> %[[SELECT1RES]], <16 x bfloat> %[[SELECT2RES]]
111+
; CHECK: %{{.*}} = bitcast <16 x bfloat> %[[SELECT3RES]] to <16 x i16>
92112
%call7 = call spir_func <16 x i16> @_Z18__builtin_bf16_minDv16_tS_(<16 x i16> %v1_16, <16 x i16> %v2_16) #2
93113
%arrayidx8 = getelementptr inbounds <16 x i16>, <16 x i16> addrspace(1)* %out16, i64 4
94114
store <16 x i16> %call7, <16 x i16> addrspace(1)* %arrayidx8, align 32
@@ -101,45 +121,65 @@ define spir_kernel void @test_max(i16 addrspace(1)* %out1, i16 zeroext %v1_1, i1
101121
entry:
102122
; CHECK: %[[SRC0BF:.*]] = bitcast i16 %v1_1 to bfloat
103123
; CHECK: %[[SRC1BF:.*]] = bitcast i16 %v2_1 to bfloat
104-
; CHECK: %[[COND:.*]] = fcmp ogt bfloat %[[SRC0BF]], %[[SRC1BF]]
105-
; CHECK: %[[SELECTRES:.*]] = select i1 %[[COND]], bfloat %[[SRC0BF]], bfloat %[[SRC1BF]]
106-
; CHECK: %{{.*}} = bitcast bfloat %[[SELECTRES]] to i16
124+
; CHECK: %[[COND1:.*]] = fcmp ogt bfloat %[[SRC0BF]], %[[SRC1BF]]
125+
; CHECK: %[[SELECT1RES:.*]] = select i1 %[[COND1]], bfloat %[[SRC0BF]], bfloat %[[SRC1BF]]
126+
; CHECK: %[[COND2:.*]] = fcmp uno bfloat %[[SRC0BF]], %[[SRC0BF]]
127+
; CHECK: %[[SELECT2RES:.*]] = select i1 %[[COND2]], bfloat %[[SRC1BF]], bfloat %[[SRC0BF]]
128+
; CHECK: %[[COND3:.*]] = fcmp ord bfloat %[[SRC0BF]], %[[SRC1BF]]
129+
; CHECK: %[[SELECT3RES:.*]] = select i1 %[[COND3]], bfloat %[[SELECT1RES]], bfloat %[[SELECT2RES]]
130+
; CHECK: %{{.*}} = bitcast bfloat %[[SELECT3RES]] to i16
107131
%call = call spir_func zeroext i16 @_Z18__builtin_bf16_maxtt(i16 zeroext %v1_1, i16 zeroext %v2_1) #2
108132
%arrayidx = getelementptr inbounds i16, i16 addrspace(1)* %out1, i64 0
109133
store i16 %call, i16 addrspace(1)* %arrayidx, align 2
110134

111135
; CHECK: %[[SRC0BF:.*]] = bitcast <2 x i16> %v1_2 to <2 x bfloat>
112136
; CHECK: %[[SRC1BF:.*]] = bitcast <2 x i16> %v2_2 to <2 x bfloat>
113-
; CHECK: %[[COND:.*]] = fcmp ogt <2 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
114-
; CHECK: %[[SELECTRES:.*]] = select <2 x i1> %[[COND]], <2 x bfloat> %[[SRC0BF]], <2 x bfloat> %[[SRC1BF]]
115-
; CHECK: %{{.*}} = bitcast <2 x bfloat> %[[SELECTRES]] to <2 x i16>
137+
; CHECK: %[[COND1:.*]] = fcmp ogt <2 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
138+
; CHECK: %[[SELECT1RES:.*]] = select <2 x i1> %[[COND1]], <2 x bfloat> %[[SRC0BF]], <2 x bfloat> %[[SRC1BF]]
139+
; CHECK: %[[COND2:.*]] = fcmp uno <2 x bfloat> %[[SRC0BF]], %[[SRC0BF]]
140+
; CHECK: %[[SELECT2RES:.*]] = select <2 x i1> %[[COND2]], <2 x bfloat> %[[SRC1BF]], <2 x bfloat> %[[SRC0BF]]
141+
; CHECK: %[[COND3:.*]] = fcmp ord <2 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
142+
; CHECK: %[[SELECT3RES:.*]] = select <2 x i1> %[[COND3]], <2 x bfloat> %[[SELECT1RES]], <2 x bfloat> %[[SELECT2RES]]
143+
; CHECK: %{{.*}} = bitcast <2 x bfloat> %[[SELECT3RES]] to <2 x i16>
116144
%call1 = call spir_func <2 x i16> @_Z18__builtin_bf16_maxDv2_tS_(<2 x i16> %v1_2, <2 x i16> %v2_2) #2
117145
%arrayidx2 = getelementptr inbounds <2 x i16>, <2 x i16> addrspace(1)* %out2, i64 1
118146
store <2 x i16> %call1, <2 x i16> addrspace(1)* %arrayidx2, align 4
119147

120148
; CHECK: %[[SRC0BF:.*]] = bitcast <4 x i16> %v1_4 to <4 x bfloat>
121149
; CHECK: %[[SRC1BF:.*]] = bitcast <4 x i16> %v2_4 to <4 x bfloat>
122-
; CHECK: %[[COND:.*]] = fcmp ogt <4 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
123-
; CHECK: %[[SELECTRES:.*]] = select <4 x i1> %[[COND]], <4 x bfloat> %[[SRC0BF]], <4 x bfloat> %[[SRC1BF]]
124-
; CHECK: %{{.*}} = bitcast <4 x bfloat> %[[SELECTRES]] to <4 x i16>
150+
; CHECK: %[[COND1:.*]] = fcmp ogt <4 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
151+
; CHECK: %[[SELECT1RES:.*]] = select <4 x i1> %[[COND1]], <4 x bfloat> %[[SRC0BF]], <4 x bfloat> %[[SRC1BF]]
152+
; CHECK: %[[COND2:.*]] = fcmp uno <4 x bfloat> %[[SRC0BF]], %[[SRC0BF]]
153+
; CHECK: %[[SELECT2RES:.*]] = select <4 x i1> %[[COND2]], <4 x bfloat> %[[SRC1BF]], <4 x bfloat> %[[SRC0BF]]
154+
; CHECK: %[[COND3:.*]] = fcmp ord <4 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
155+
; CHECK: %[[SELECT3RES:.*]] = select <4 x i1> %[[COND3]], <4 x bfloat> %[[SELECT1RES]], <4 x bfloat> %[[SELECT2RES]]
156+
; CHECK: %{{.*}} = bitcast <4 x bfloat> %[[SELECT3RES]] to <4 x i16>
125157
%call3 = call spir_func <4 x i16> @_Z18__builtin_bf16_maxDv4_tS_(<4 x i16> %v1_4, <4 x i16> %v2_4) #2
126158
%arrayidx4 = getelementptr inbounds <4 x i16>, <4 x i16> addrspace(1)* %out4, i64 2
127159
store <4 x i16> %call3, <4 x i16> addrspace(1)* %arrayidx4, align 8
128160

129161
; CHECK: %[[SRC0BF:.*]] = bitcast <8 x i16> %v1_8 to <8 x bfloat>
130162
; CHECK: %[[SRC1BF:.*]] = bitcast <8 x i16> %v2_8 to <8 x bfloat>
131-
; CHECK: %[[COND:.*]] = fcmp ogt <8 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
132-
; CHECK: %[[SELECTRES:.*]] = select <8 x i1> %[[COND]], <8 x bfloat> %[[SRC0BF]], <8 x bfloat> %[[SRC1BF]]
133-
; CHECK: %{{.*}} = bitcast <8 x bfloat> %[[SELECTRES]] to <8 x i16>
163+
; CHECK: %[[COND1:.*]] = fcmp ogt <8 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
164+
; CHECK: %[[SELECT1RES:.*]] = select <8 x i1> %[[COND1]], <8 x bfloat> %[[SRC0BF]], <8 x bfloat> %[[SRC1BF]]
165+
; CHECK: %[[COND2:.*]] = fcmp uno <8 x bfloat> %[[SRC0BF]], %[[SRC0BF]]
166+
; CHECK: %[[SELECT2RES:.*]] = select <8 x i1> %[[COND2]], <8 x bfloat> %[[SRC1BF]], <8 x bfloat> %[[SRC0BF]]
167+
; CHECK: %[[COND3:.*]] = fcmp ord <8 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
168+
; CHECK: %[[SELECT3RES:.*]] = select <8 x i1> %[[COND3]], <8 x bfloat> %[[SELECT1RES]], <8 x bfloat> %[[SELECT2RES]]
169+
; CHECK: %{{.*}} = bitcast <8 x bfloat> %[[SELECT3RES]] to <8 x i16>
134170
%call5 = call spir_func <8 x i16> @_Z18__builtin_bf16_maxDv8_tS_(<8 x i16> %v1_8, <8 x i16> %v2_8) #2
135171
%arrayidx6 = getelementptr inbounds <8 x i16>, <8 x i16> addrspace(1)* %out8, i64 3
136172
store <8 x i16> %call5, <8 x i16> addrspace(1)* %arrayidx6, align 16
137173

138174
; CHECK: %[[SRC0BF:.*]] = bitcast <16 x i16> %v1_16 to <16 x bfloat>
139175
; CHECK: %[[SRC1BF:.*]] = bitcast <16 x i16> %v2_16 to <16 x bfloat>
140-
; CHECK: %[[COND:.*]] = fcmp ogt <16 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
141-
; CHECK: %[[SELECTRES:.*]] = select <16 x i1> %[[COND]], <16 x bfloat> %[[SRC0BF]], <16 x bfloat> %[[SRC1BF]]
142-
; CHECK: %{{.*}} = bitcast <16 x bfloat> %[[SELECTRES]] to <16 x i16>
176+
; CHECK: %[[COND1:.*]] = fcmp ogt <16 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
177+
; CHECK: %[[SELECT1RES:.*]] = select <16 x i1> %[[COND1]], <16 x bfloat> %[[SRC0BF]], <16 x bfloat> %[[SRC1BF]]
178+
; CHECK: %[[COND2:.*]] = fcmp uno <16 x bfloat> %[[SRC0BF]], %[[SRC0BF]]
179+
; CHECK: %[[SELECT2RES:.*]] = select <16 x i1> %[[COND2]], <16 x bfloat> %[[SRC1BF]], <16 x bfloat> %[[SRC0BF]]
180+
; CHECK: %[[COND3:.*]] = fcmp ord <16 x bfloat> %[[SRC0BF]], %[[SRC1BF]]
181+
; CHECK: %[[SELECT3RES:.*]] = select <16 x i1> %[[COND3]], <16 x bfloat> %[[SELECT1RES]], <16 x bfloat> %[[SELECT2RES]]
182+
; CHECK: %{{.*}} = bitcast <16 x bfloat> %[[SELECT3RES]] to <16 x i16>
143183
%call7 = call spir_func <16 x i16> @_Z18__builtin_bf16_maxDv16_tS_(<16 x i16> %v1_16, <16 x i16> %v2_16) #2
144184
%arrayidx8 = getelementptr inbounds <16 x i16>, <16 x i16> addrspace(1)* %out16, i64 4
145185
store <16 x i16> %call7, <16 x i16> addrspace(1)* %arrayidx8, align 32

0 commit comments

Comments
 (0)