Skip to content

Commit 34c8a74

Browse files
Sharma, Rithik and igcbot
authored and committed
Add QuadBroadcast optimization for WaveShuffleIndex pattern
Adds a new QuadBroadcast intrinsic for WaveShuffleIndex operation.
1 parent 26a57ba commit 34c8a74

File tree

6 files changed

+281
-52
lines changed

6 files changed

+281
-52
lines changed

IGC/Compiler/CISACodeGen/EmitVISAPass.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6051,6 +6051,24 @@ void EmitPass::emitSimdClusteredBroadcast(llvm::Instruction* inst)
60516051

60526052
}
60536053

6054+
// Emits the copy that implements the GenISA_QuadBroadcast intrinsic.
//
//   operand 0 - the value to broadcast.
//   operand 1 - the lane to read; asserted to be a ConstantInt below 4.
//
// A uniform source is copied to the destination as-is. Otherwise the copy is
// issued with NoMask, a source region of (4, 4, 0) on source 0 and the lane
// constant as the source sub-register.
//
// NOTE(review): the lane operand is only validated through IGC_ASSERT. If that
// macro compiles out in release builds, a non-constant operand would be
// dereferenced as a null pointer on the non-uniform path - confirm that every
// producer of this intrinsic passes a constant in [0, 3].
void EmitPass::emitQuadBroadcast(llvm::Instruction* inst)
{
    CVariable* const src = GetSymbol(inst->getOperand(0));
    ConstantInt* const quadLane = dyn_cast<ConstantInt>(inst->getOperand(1));
    IGC_ASSERT(quadLane && quadLane->getZExtValue() < 4);

    if (!src->IsUniform())
    {
        // Presumably the region arguments are (vertical stride, width,
        // horizontal stride), so each group of four channels re-reads one
        // element selected by the sub-register offset - TODO confirm against
        // the encoder's SetSrcRegion / SetSrcSubReg declarations.
        m_encoder->SetNoMask();
        m_encoder->SetSrcRegion(0, 4, 4, 0);
        m_encoder->SetSrcSubReg(0, quadLane->getZExtValue());
    }

    // Both the uniform and the non-uniform path end with the same copy.
    m_encoder->Copy(m_destination, src);
    m_encoder->Push();
}
6071+
60546072
void EmitPass::emitSimdShuffleDown(llvm::Instruction* inst)
60556073
{
60566074
CVariable* pCurrentData = GetSymbol(inst->getOperand(0));
@@ -9476,6 +9494,9 @@ void EmitPass::EmitGenIntrinsicMessage(llvm::GenIntrinsicInst* inst)
94769494
case GenISAIntrinsic::GenISA_WaveBroadcast:
94779495
emitSimdShuffle(inst);
94789496
break;
9497+
case GenISAIntrinsic::GenISA_QuadBroadcast:
9498+
emitQuadBroadcast(inst);
9499+
break;
94799500
case GenISAIntrinsic::GenISA_WaveClusteredBroadcast:
94809501
emitSimdClusteredBroadcast(inst);
94819502
break;

IGC/Compiler/CISACodeGen/EmitVISAPass.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ class EmitPass : public llvm::FunctionPass
249249
void emitSimdSize(llvm::Instruction* inst);
250250
void emitSimdShuffle(llvm::Instruction* inst);
251251
void emitSimdClusteredBroadcast(llvm::Instruction* inst);
252+
void emitQuadBroadcast(llvm::Instruction* inst);
252253
void emitCrossInstanceMov(const SSource& source, const DstModifier& modifier);
253254
void emitSimdShuffleDown(llvm::Instruction* inst);
254255
void emitSimdShuffleXor(llvm::Instruction* inst);

IGC/Compiler/CISACodeGen/opCode.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ DECLARE_OPCODE(GenISA_WavePrefix, GenISAIntrinsic, llvm_wavePrefix, false, false
292292
DECLARE_OPCODE(GenISA_QuadPrefix, GenISAIntrinsic, llvm_quadPrefix, false, false, false, false, false, false, false)
293293
DECLARE_OPCODE(GenISA_WaveClusteredPrefix, GenISAIntrinsic, llvm_waveClusteredPrefix, false, false, false, false, false, false, false)
294294
DECLARE_OPCODE(GenISA_WaveShuffleIndex, GenISAIntrinsic, llvm_waveShuffleIndex, false, false, false, false, false, false, false)
295+
DECLARE_OPCODE(GenISA_QuadBroadcast, GenISAIntrinsic, llvm_QuadBroadcast, false, false, false, false, false, false, false)
295296
DECLARE_OPCODE(GenISA_WaveBroadcast, GenISAIntrinsic, llvm_waveBroadcast, false, false, false, false, false, false, false)
296297
DECLARE_OPCODE(GenISA_WaveClusteredBroadcast, GenISAIntrinsic, llvm_waveClusteredBroadcast, false, false, false, false, false, false, false)
297298

IGC/Compiler/CustomSafeOptPass.cpp

Lines changed: 112 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,7 @@ void CustomSafeOptPass::visitAnd(BinaryOperator& I) {
289289
// also be written manually as
290290
// uint32_t other_id = sg.get_local_id() ^ XOR_VALUE;
291291
// r = select_from_group(sg, x, other_id);
292-
void CustomSafeOptPass::visitShuffleIndex(llvm::CallInst* I)
293-
{
292+
void CustomSafeOptPass::visitShuffleIndex(llvm::CallInst* I) {
294293
using namespace llvm::PatternMatch;
295294
/*
296295
Pattern match
@@ -299,87 +298,148 @@ void CustomSafeOptPass::visitShuffleIndex(llvm::CallInst* I)
299298
%xor = xor i16 %[optional1], 1
300299
...[optional2] = %xor
301300
%simdShuffle = call i32 @llvm.genx.GenISA.WaveShuffleIndex.i32(i32 %x, i32 %[optional2], i32 0)
302-
303-
Optional can be any combinations of :
301+
Optional can be any combinations of:
304302
* %and = and i16 %856, 63
305303
* %zext = zext i16 %857 to i32
306304
We ignore any combinations of those, as they don't change the final calculated value,
307305
and different permutations were observed.
308306
*/
309307

308+
auto getInstructionIgnoringAndZext = [](Value* V, unsigned Opcode) -> Instruction* {
309+
while (auto* VI = dyn_cast<Instruction>(V)) {
310+
if (VI->getOpcode() == Opcode) {
311+
return VI;
312+
}
313+
else if (auto* ZI = dyn_cast<ZExtInst>(VI)) {
314+
// Check if zext is from i16 to i32
315+
if (ZI->getSrcTy()->isIntegerTy(16) && ZI->getDestTy()->isIntegerTy(32)) {
316+
V = ZI->getOperand(0); // Skip over zext
317+
}
318+
else {
319+
return nullptr; // Not the zext we are looking for
320+
}
321+
}
322+
else if (VI->getOpcode() == Instruction::And) {
323+
ConstantInt* andValueConstant = dyn_cast<ConstantInt>(VI->getOperand(1));
324+
// We handle "redundant values", so those which bits enable all of
325+
// 32 lanes, so 31, 63 (spotted in nature), 127, 255 etc.
326+
if (andValueConstant && ((andValueConstant->getZExtValue() & 31) != 31)) {
327+
return nullptr;
328+
}
329+
V = VI->getOperand(0); // Skip over and
330+
}
331+
else {
332+
return nullptr; // Not a zext, and, or the specified opcode
333+
}
334+
}
335+
return nullptr; //unreachable
336+
};
337+
338+
Value* indexOp = I->getOperand(1);
339+
340+
// Get helper lanes parameter
310341
ConstantInt* enableHelperLanes = dyn_cast<ConstantInt>(I->getOperand(2));
311-
if (!enableHelperLanes || enableHelperLanes->getZExtValue() != 0) {
342+
if (!enableHelperLanes) {
312343
return;
313344
}
314345

315-
auto getInstructionIgnoringAndZext = []( Value* V, unsigned Opcode ) -> Instruction* {
316-
while( auto* VI = dyn_cast<Instruction>( V ) ) {
317-
if( VI->getOpcode() == Opcode ) {
318-
return VI;
319-
}
320-
else if( auto* ZI = dyn_cast<ZExtInst>( VI ) ) {
321-
// Check if zext is from i16 to i32
322-
if( ZI->getSrcTy()->isIntegerTy( 16 ) && ZI->getDestTy()->isIntegerTy( 32 ) ) {
323-
V = ZI->getOperand( 0 ); // Skip over zext
324-
} else {
325-
return nullptr; // Not the zext we are looking for
346+
// Try QuadBroadcast pattern if helper lanes = 1
347+
if (enableHelperLanes->getZExtValue() == 1) {
348+
auto* zextInst = dyn_cast<ZExtInst>(indexOp);
349+
if (zextInst && zextInst->getSrcTy()->isIntegerTy(16) &&
350+
zextInst->getDestTy()->isIntegerTy(32)) {
351+
352+
auto* andInst = dyn_cast<Instruction>(zextInst->getOperand(0));
353+
if (andInst && andInst->getOpcode() == Instruction::And) {
354+
// Check for mask constant -4 (0xFFFC)
355+
auto* mask = dyn_cast<ConstantInt>(andInst->getOperand(1));
356+
if (mask && mask->getSExtValue() == -4) {
357+
uint32_t laneIdx = 0;
358+
Value* simdLaneOp = andInst->getOperand(0);
359+
360+
// Check for or operation
361+
if (auto* orInst = dyn_cast<Instruction>(simdLaneOp)) {
362+
if (orInst->getOpcode() == Instruction::Or) {
363+
auto* constOffset = dyn_cast<ConstantInt>(orInst->getOperand(1));
364+
// Return if OR value is not a constant or is >= 4
365+
if (!constOffset || constOffset->getZExtValue() >= 4) {
366+
return;
367+
}
368+
laneIdx = constOffset->getZExtValue() & 0x3;
369+
simdLaneOp = orInst->getOperand(0);
370+
}
326371
}
327-
}
328-
else if( VI->getOpcode() == Instruction::And ) {
329-
ConstantInt* andValueConstant = dyn_cast<ConstantInt>( VI->getOperand( 1 ) );
330-
// We handle "redundant values", so those which bits enable all of
331-
// 32 lanes, so 31, 63 (spotted in nature), 127, 255 etc.
332-
if( andValueConstant && (( andValueConstant->getZExtValue() & 31 ) != 31 ) ) {
333-
return nullptr;
372+
373+
// Check for simdLaneId
374+
auto* simdLaneCall = dyn_cast<CallInst>(simdLaneOp);
375+
if (simdLaneCall) {
376+
Function* simdIdF = simdLaneCall->getCalledFunction();
377+
if (simdIdF &&
378+
GenISAIntrinsic::getIntrinsicID(simdIdF) == GenISAIntrinsic::GenISA_simdLaneId) {
379+
380+
// Pattern matched - create QuadBroadcast
381+
IRBuilder<> builder(I);
382+
383+
Function* quadBroadcastFunc = GenISAIntrinsic::getDeclaration(
384+
builder.GetInsertBlock()->getParent()->getParent(),
385+
GenISAIntrinsic::GenISA_QuadBroadcast,
386+
I->getType());
387+
388+
Value* result = builder.CreateCall(quadBroadcastFunc,
389+
{ I->getOperand(0), builder.getInt32(laneIdx) },
390+
"quadBroadcast");
391+
392+
I->replaceAllUsesWith(result);
393+
I->eraseFromParent();
394+
return;
395+
}
334396
}
335-
V = VI->getOperand( 0 ); // Skip over and
336-
} else {
337-
return nullptr; // Not a zext, and, or the specified opcode
338397
}
339398
}
340-
return nullptr; //unreachable
341-
};
399+
}
400+
}
401+
402+
// Try ShuffleXor pattern if helper lanes = 0
403+
if (enableHelperLanes->getZExtValue() != 0) {
404+
return;
405+
}
342406

343-
Instruction* xorInst = getInstructionIgnoringAndZext( I->getOperand( 1 ), Instruction::Xor );
344-
if( !xorInst )
407+
Instruction* xorInst = getInstructionIgnoringAndZext(indexOp, Instruction::Xor);
408+
if (!xorInst)
345409
return;
346410

347-
auto xorOperand = xorInst->getOperand( 0 );
348-
auto xorValueConstant = dyn_cast<ConstantInt> ( xorInst->getOperand( 1 ) );
349-
if( !xorValueConstant )
411+
auto xorOperand = xorInst->getOperand(0);
412+
auto xorValueConstant = dyn_cast<ConstantInt>(xorInst->getOperand(1));
413+
if (!xorValueConstant)
350414
return;
351415

352416
uint64_t xorValue = xorValueConstant->getZExtValue();
353-
if( xorValue >= 16 )
354-
{
417+
if (xorValue >= 16) {
355418
// currently not supported in the emitter
356419
return;
357420
}
358421

359-
auto simdLaneCandidate = getInstructionIgnoringAndZext( xorOperand, Instruction::Call );
360-
422+
auto simdLaneCandidate = getInstructionIgnoringAndZext(xorOperand, Instruction::Call);
361423
if (!simdLaneCandidate)
362424
return;
363425

364-
CallInst* CI = cast<CallInst>( simdLaneCandidate );
426+
CallInst* CI = cast<CallInst>(simdLaneCandidate);
365427
Function* simdIdF = CI->getCalledFunction();
366-
if( !simdIdF || GenISAIntrinsic::getIntrinsicID( simdIdF ) != GenISAIntrinsic::GenISA_simdLaneId)
428+
if (!simdIdF || GenISAIntrinsic::getIntrinsicID(simdIdF) != GenISAIntrinsic::GenISA_simdLaneId)
367429
return;
368430

369-
// since we didn't return earlier, pattern is found
370-
431+
// ShuffleXor pattern found
371432
auto insertShuffleXor = [](IRBuilder<>& builder,
372-
Value* value,
373-
uint32_t xorValue)
374-
{
375-
Function* simdShuffleXorFunc = GenISAIntrinsic::getDeclaration(
376-
builder.GetInsertBlock()->getParent()->getParent(),
377-
GenISAIntrinsic::GenISA_simdShuffleXor,
378-
value->getType());
379-
380-
return builder.CreateCall(simdShuffleXorFunc,
381-
{ value, builder.getInt32(xorValue) }, "simdShuffleXor");
382-
};
433+
Value* value,
434+
uint32_t xorValue) {
435+
Function* simdShuffleXorFunc = GenISAIntrinsic::getDeclaration(
436+
builder.GetInsertBlock()->getParent()->getParent(),
437+
GenISAIntrinsic::GenISA_simdShuffleXor,
438+
value->getType());
439+
440+
return builder.CreateCall(simdShuffleXorFunc,
441+
{ value, builder.getInt32(xorValue) }, "simdShuffleXor");
442+
};
383443

384444
Value* value = I->getOperand(0);
385445
IRBuilder<> builder(I);
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
;=========================== begin_copyright_notice ============================
2+
;
3+
; Copyright (C) 2017-2022 Intel Corporation
4+
;
5+
; SPDX-License-Identifier: MIT
6+
;
7+
;============================ end_copyright_notice =============================
8+
9+
; RUN: igc_opt -igc-custom-safe-opt -S %s -o %t.ll
10+
; RUN: FileCheck %s --input-file=%t.ll
11+
12+
declare i16 @llvm.genx.GenISA.simdLaneId()
13+
declare float @llvm.genx.GenISA.WaveShuffleIndex.f32(float, i32, i32)
14+
declare float @llvm.genx.GenISA.QuadBroadcast.f32(float, i32)
15+
16+
; Test basic quad broadcast pattern for lane 0
17+
; CHECK-LABEL: @test_quad_broadcast_lane0
18+
define float @test_quad_broadcast_lane0(float %x) nounwind {
19+
entry:
20+
%lane = call i16 @llvm.genx.GenISA.simdLaneId()
21+
%masked = and i16 %lane, -4 ; Mask to quad boundary (0xFFFC)
22+
%idx = zext i16 %masked to i32
23+
; CHECK: call float @llvm.genx.GenISA.QuadBroadcast.f32(float %x, i32 0)
24+
%result = call float @llvm.genx.GenISA.WaveShuffleIndex.f32(float %x, i32 %idx, i32 1)
25+
ret float %result
26+
}
27+
28+
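; Test basic quad broadcast pattern for lane 1
; NOTE(review): the index below is (%lane | 1) & -4. The mask clears the two low
; bits that the OR just set, so the value equals %lane & -4 and WaveShuffleIndex
; reads lane 0 of the quad, not lane 1. Confirm whether the intended source
; pattern is (%lane & -4) | 1 before relying on the QuadBroadcast(..., 1)
; expectation; the lane 2 and lane 3 tests that follow have the same shape.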
29+
; CHECK-LABEL: @test_quad_broadcast_lane1
30+
define float @test_quad_broadcast_lane1(float %x) nounwind {
31+
entry:
32+
%lane = call i16 @llvm.genx.GenISA.simdLaneId()
33+
%lane1 = or i16 %lane, 1 ; Set bit for lane 1
34+
%masked = and i16 %lane1, -4 ; Mask to quad boundary
35+
%idx = zext i16 %masked to i32
36+
; CHECK: call float @llvm.genx.GenISA.QuadBroadcast.f32(float %x, i32 1)
37+
%result = call float @llvm.genx.GenISA.WaveShuffleIndex.f32(float %x, i32 %idx, i32 1)
38+
ret float %result
39+
}
40+
41+
; Test basic quad broadcast pattern for lane 2
42+
; CHECK-LABEL: @test_quad_broadcast_lane2
43+
define float @test_quad_broadcast_lane2(float %x) nounwind {
44+
entry:
45+
%lane = call i16 @llvm.genx.GenISA.simdLaneId()
46+
%lane2 = or i16 %lane, 2 ; Set bit for lane 2
47+
%masked = and i16 %lane2, -4 ; Mask to quad boundary
48+
%idx = zext i16 %masked to i32
49+
; CHECK: call float @llvm.genx.GenISA.QuadBroadcast.f32(float %x, i32 2)
50+
%result = call float @llvm.genx.GenISA.WaveShuffleIndex.f32(float %x, i32 %idx, i32 1)
51+
ret float %result
52+
}
53+
54+
; Test basic quad broadcast pattern for lane 3
55+
; CHECK-LABEL: @test_quad_broadcast_lane3
56+
define float @test_quad_broadcast_lane3(float %x) nounwind {
57+
entry:
58+
%lane = call i16 @llvm.genx.GenISA.simdLaneId()
59+
%lane3 = or i16 %lane, 3 ; Set bit for lane 3
60+
%masked = and i16 %lane3, -4 ; Mask to quad boundary
61+
%idx = zext i16 %masked to i32
62+
; CHECK: call float @llvm.genx.GenISA.QuadBroadcast.f32(float %x, i32 3)
63+
%result = call float @llvm.genx.GenISA.WaveShuffleIndex.f32(float %x, i32 %idx, i32 1)
64+
ret float %result
65+
}
66+
67+
; Test that we don't transform when helper lanes = 0
68+
; CHECK-LABEL: @test_no_transform_helper_lanes
69+
define float @test_no_transform_helper_lanes(float %x) nounwind {
70+
entry:
71+
%lane = call i16 @llvm.genx.GenISA.simdLaneId()
72+
%masked = and i16 %lane, -4
73+
%idx = zext i16 %masked to i32
74+
; CHECK: call float @llvm.genx.GenISA.WaveShuffleIndex.f32(float %x, i32 %idx, i32 0)
75+
%result = call float @llvm.genx.GenISA.WaveShuffleIndex.f32(float %x, i32 %idx, i32 0)
76+
ret float %result
77+
}
78+
79+
; Test that we don't transform when using different AND mask
80+
; CHECK-LABEL: @test_no_transform_different_mask
81+
define float @test_no_transform_different_mask(float %x) nounwind {
82+
entry:
83+
%lane = call i16 @llvm.genx.GenISA.simdLaneId()
84+
%masked = and i16 %lane, -8 ; Different mask, not -4 (0xFFFC)
85+
%idx = zext i16 %masked to i32
86+
; CHECK: call float @llvm.genx.GenISA.WaveShuffleIndex.f32(float %x, i32 %idx, i32 1)
87+
%result = call float @llvm.genx.GenISA.WaveShuffleIndex.f32(float %x, i32 %idx, i32 1)
88+
ret float %result
89+
}
90+
91+
; Test that we don't transform when OR constant is too large
92+
; CHECK-LABEL: @test_no_transform_large_lane
93+
define float @test_no_transform_large_lane(float %x) nounwind {
94+
entry:
95+
%lane = call i16 @llvm.genx.GenISA.simdLaneId()
96+
%lane4 = or i16 %lane, 4 ; Invalid quad lane (must be 0-3)
97+
%masked = and i16 %lane4, -4 ; Mask to quad boundary
98+
%idx = zext i16 %masked to i32
99+
; CHECK: call float @llvm.genx.GenISA.WaveShuffleIndex.f32(float %x, i32 %idx, i32 1)
100+
%result = call float @llvm.genx.GenISA.WaveShuffleIndex.f32(float %x, i32 %idx, i32 1)
101+
ret float %result
102+
}
103+
104+
; Test that we don't transform when OR uses non-constant value
105+
; CHECK-LABEL: @test_no_transform_variable_lane
106+
define float @test_no_transform_variable_lane(float %x, i16 %lane_val) nounwind {
107+
entry:
108+
%lane = call i16 @llvm.genx.GenISA.simdLaneId()
109+
%laneN = or i16 %lane, %lane_val ; Variable lane index
110+
%masked = and i16 %laneN, -4 ; Mask to quad boundary
111+
%idx = zext i16 %masked to i32
112+
; CHECK: call float @llvm.genx.GenISA.WaveShuffleIndex.f32(float %x, i32 %idx, i32 1)
113+
%result = call float @llvm.genx.GenISA.WaveShuffleIndex.f32(float %x, i32 %idx, i32 1)
114+
ret float %result
115+
}
116+
117+
; Test that we don't transform valid quad pattern when helper_lanes = 0
118+
; CHECK-LABEL: @test_no_transform_valid_lane_wrong_helper
119+
define float @test_no_transform_valid_lane_wrong_helper(float %x) nounwind {
120+
entry:
121+
%lane = call i16 @llvm.genx.GenISA.simdLaneId()
122+
%lane1 = or i16 %lane, 1 ; Valid lane (1)
123+
%masked = and i16 %lane1, -4 ; Correct mask
124+
%idx = zext i16 %masked to i32
125+
; CHECK: call float @llvm.genx.GenISA.WaveShuffleIndex.f32(float %x, i32 %idx, i32 0)
126+
%result = call float @llvm.genx.GenISA.WaveShuffleIndex.f32(float %x, i32 %idx, i32 0)
127+
ret float %result
128+
}

0 commit comments

Comments
 (0)