Skip to content

Commit 7fa42ec

Browse files
azabaznosys_zuul
authored andcommitted
Indirect regions legalization improvement
Change-Id: Id2a58b95279a8bb1558583e594330d278815ec4c
1 parent 0e27fdc commit 7fa42ec

File tree

6 files changed

+137
-38
lines changed

6 files changed

+137
-38
lines changed

IGC/VectorCompiler/lib/GenXCodeGen/GenXAlignmentInfo.cpp

Lines changed: 73 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
3737
//===----------------------------------------------------------------------===//
3838
#define DEBUG_TYPE "GENX_ALIGNMENT_INFO"
3939

40+
#include "IGC/common/StringMacros.hpp"
41+
4042
#include <algorithm>
4143
#include "GenX.h"
4244
#include "GenXAlignmentInfo.h"
@@ -63,6 +65,7 @@ Alignment AlignmentInfo::get(Value *V)
6365
if (auto C = dyn_cast<Constant>(V))
6466
return Alignment(C);
6567
auto Inst = dyn_cast<Instruction>(V);
68+
6669
if (!Inst) {
6770
// An Argument has unknown alignment.
6871
// (FIXME: We may need to do better than this, tracing the value of the
@@ -141,8 +144,10 @@ Alignment AlignmentInfo::get(Value *V)
141144
Alignment A(0, 0); // assume unknown
142145
if (BinaryOperator *BO = dyn_cast<BinaryOperator>(WorkInst)) {
143146
A = Alignment(); // assume uncomputed
144-
Alignment A0 = getFromInstMap(BO->getOperand(0));
145-
Alignment A1 = getFromInstMap(BO->getOperand(1));
147+
auto *Op0 = BO->getOperand(0);
148+
auto *Op1 = BO->getOperand(1);
149+
Alignment A0 = getFromInstMap(Op0);
150+
Alignment A1 = getFromInstMap(Op1);
146151
if (!A0.isUncomputed() && !A1.isUncomputed()) {
147152
switch (BO->getOpcode()) {
148153
case Instruction::Add:
@@ -164,9 +169,17 @@ Alignment AlignmentInfo::get(Value *V)
164169
} else
165170
A = Alignment::getUnknown();
166171
break;
167-
default:
168-
A = Alignment::getUnknown();
169-
break;
172+
case Instruction::And:
173+
if (auto *CI0 = dyn_cast<ConstantInt>(Op0)) {
174+
A = A1.logicalAnd(CI0);
175+
} else if (auto *CI1 = dyn_cast<ConstantInt>(Op1)) {
176+
A = A0.logicalAnd(CI1);
177+
} else
178+
A = Alignment::getUnknown();
179+
break;
180+
default:
181+
A = Alignment::getUnknown();
182+
break;
170183
}
171184
}
172185
} else if (CastInst *CI = dyn_cast<CastInst>(WorkInst)) {
@@ -201,11 +214,11 @@ Alignment AlignmentInfo::get(Value *V)
201214
switch (GenXIntrinsic::getGenXIntrinsicID(WorkInst)) {
202215
case GenXIntrinsic::genx_rdregioni:
203216
case GenXIntrinsic::genx_rdregionf: {
204-
// Handle the case of reading a scalar from element 0 of a vector, as
217+
// Handle the case of reading a scalar from element of a vector, as
205218
// a trunc from i32 to i16 is lowered to a bitcast to v2i16 then a
206219
// rdregion.
207220
Region R(WorkInst, BaleInfo());
208-
if (!R.Indirect && !R.Offset)
221+
if (!R.Indirect && (R.NumElements == 1))
209222
A = getFromInstMap(WorkInst->getOperand(0));
210223
else
211224
A = Alignment(0, 0);
@@ -275,7 +288,30 @@ Alignment::Alignment(unsigned C)
275288
{
276289
LogAlign = countTrailingZeros(C);
277290
ExtraBits = 0;
278-
ConstBits = (C < 0x7fffffff)? C : 0x7fffffff;
291+
ConstBits = (C < MaskForUnknown) ? C : MaskForUnknown;
292+
}
293+
294+
Alignment Alignment::getAlignmentForConstant(Constant *C) {
295+
IGC_ASSERT(!isa<VectorType>(C->getType()));
296+
Alignment A;
297+
A.setUncomputed();
298+
if (isa<UndefValue>(C)) {
299+
A.LogAlign = 31;
300+
A.ExtraBits = 0;
301+
A.ConstBits = MaskForUnknown;
302+
} else if (auto CI = dyn_cast<ConstantInt>(C)) {
303+
int64_t SVal = CI->getSExtValue();
304+
// Get least significant bits to count LogAlign
305+
unsigned LSBBits = SVal & UnsignedAllOnes;
306+
A.LogAlign = countTrailingZeros(LSBBits);
307+
308+
A.ExtraBits = 0;
309+
A.ConstBits = MaskForUnknown;
310+
if (SVal < MaskForUnknown && SVal >= 0 &&
311+
SVal <= std::numeric_limits<unsigned>::max())
312+
A.ConstBits = static_cast<unsigned>(SVal);
313+
}
314+
return A;
279315
}
280316

281317
/***********************************************************************
@@ -284,19 +320,18 @@ Alignment::Alignment(unsigned C)
284320
Alignment::Alignment(Constant *C)
285321
{
286322
setUncomputed();
287-
if (isa<VectorType>(C->getType()))
288-
C = C->getAggregateElement(0U);
289-
if (isa<UndefValue>(C)) {
290-
LogAlign = 31;
291-
ExtraBits = 0;
292-
ConstBits = 0x7fffffff;
293-
} else if (auto CI = dyn_cast<ConstantInt>(C)) {
294-
LogAlign = countTrailingZeros((unsigned)(CI->getSExtValue()));
295-
ExtraBits = 0;
296-
ConstBits = 0x7fffffff;
297-
if (CI->getSExtValue() < 0x7fffffff && CI->getSExtValue() >= 0)
298-
ConstBits = (unsigned)(CI->getSExtValue());
323+
if (auto *VT = dyn_cast<VectorType>(C->getType())) {
324+
// Take splat if exists
325+
if (auto *SplatVal = C->getSplatValue())
326+
C = SplatVal;
327+
else {
328+
// Otherwise be conservative and pretend alignment
329+
// unknown for non-splat vectors
330+
*this = Alignment::getUnknown();
331+
return;
332+
}
299333
}
334+
*this = getAlignmentForConstant(C);
300335
}
301336

302337
/***********************************************************************
@@ -365,6 +400,24 @@ Alignment Alignment::mul(Alignment Other) const
365400
return Alignment(MinLogAlign, ExtraBits2 & ((1 << MinLogAlign) - 1));
366401
}
367402

403+
/***********************************************************************
404+
* logicalAnd : logical and two alignments. Only constant int supported.
405+
*/
406+
Alignment Alignment::logicalAnd(ConstantInt *CI) const {
407+
IGC_ASSERT(!isUncomputed() && CI);
408+
// If value doesn't fit into unsigned then be conservative and pretend
409+
// that alignement is unknown
410+
int64_t Val = CI->getSExtValue();
411+
if (Val < std::numeric_limits<int>::min() ||
412+
Val > std::numeric_limits<int>::max())
413+
return Alignment::getUnknown();
414+
unsigned UVal = static_cast<unsigned>(std::abs(Val));
415+
unsigned ValLSB = countTrailingZeros(UVal, ZB_Width);
416+
// Chop off constant bits according to maximum log align
417+
unsigned NewLogAlign = std::max(ValLSB, LogAlign);
418+
return Alignment(NewLogAlign, UVal & ((1 << NewLogAlign) - 1));
419+
}
420+
368421
/***********************************************************************
369422
* getFromInstMap : get the alignment of a value, direct from InstMap if
370423
* found else return Unknown, Alignment(0, 0)

IGC/VectorCompiler/lib/GenXCodeGen/GenXAlignmentInfo.h

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,16 @@ class Alignment {
7171
unsigned LogAlign;
7272
unsigned ExtraBits;
7373
unsigned ConstBits;
74+
75+
static constexpr unsigned MaskForUnknown = 0x7fffffff;
76+
static constexpr unsigned UnsignedAllOnes = 0xffffffff;
77+
7478
public:
7579
// No-arg constructor sets to uncomputed state.
7680
Alignment() { setUncomputed(); }
7781
// Constructor given LogAlign and ExtraBits fields.
7882
Alignment(unsigned LogAlign, unsigned ExtraBits)
79-
: LogAlign(LogAlign), ExtraBits(ExtraBits), ConstBits(0x7fffffff) {}
83+
: LogAlign(LogAlign), ExtraBits(ExtraBits), ConstBits(MaskForUnknown) {}
8084
// Constructor given literal value.
8185
Alignment(unsigned C);
8286
// Constructor given Constant.
@@ -103,11 +107,17 @@ class Alignment {
103107
Alignment add(Alignment Other) const;
104108
// Mul one Alignment with another Alignment
105109
Alignment mul(Alignment Other) const;
110+
// Logical and Alignment with constant integer
111+
Alignment logicalAnd(ConstantInt *CI) const;
106112

107113
// accessors
108-
bool isUncomputed() const { return LogAlign == 0xffffffff; }
109-
bool isUnknown() const { return LogAlign == 0 && ConstBits == 0x7fffffff; }
110-
bool isConstant() const { return !isUncomputed() && ConstBits != 0x7fffffff; }
114+
bool isUncomputed() const { return LogAlign == UnsignedAllOnes; }
115+
bool isUnknown() const {
116+
return LogAlign == 0 && ConstBits == MaskForUnknown;
117+
}
118+
bool isConstant() const {
119+
return !isUncomputed() && ConstBits != MaskForUnknown;
120+
}
111121
unsigned getLogAlign() const { IGC_ASSERT(!isUncomputed()); return LogAlign; }
112122
unsigned getExtraBits() const { IGC_ASSERT(!isUncomputed()); return ExtraBits; }
113123
int64_t getConstBits() const { IGC_ASSERT(isConstant()); return ConstBits; }
@@ -117,14 +127,15 @@ class Alignment {
117127
ExtraBits == Rhs.ExtraBits &&
118128
ConstBits == Rhs.ConstBits);
119129
}
130+
static Alignment getAlignmentForConstant(Constant *C);
120131
// Debug dump/print
121132
void dump() const;
122133
void print(raw_ostream &OS) const;
123134
private:
124135
void setUncomputed() {
125-
LogAlign = 0xffffffff;
136+
LogAlign = UnsignedAllOnes;
126137
ExtraBits = 0;
127-
ConstBits = 0x7fffffff;
138+
ConstBits = MaskForUnknown;
128139
}
129140
};
130141

IGC/VectorCompiler/lib/GenXCodeGen/GenXRegion.cpp

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -361,13 +361,8 @@ unsigned Region::getLegalSize(unsigned Idx, bool Allow2D,
361361
// operand or is a 1D source operand (so GenXCisaBuilder can turn it
362362
// into Nx1 instead of 1xN). We use Allow2D as a proxy for "is source
363363
// operand".
364-
unsigned GRFsPerIndirect = 1;
365-
IGC_ASSERT(ST);
366-
if (ST->hasIndirectGRFCrossing() &&
367-
// SKL+. See if we can allow GRF crossing.
368-
(Allow2D || !is2D())) {
369-
GRFsPerIndirect = 2;
370-
}
364+
unsigned GRFsPerIndirect =
365+
genx::getNumGRFsPerIndirectForRegion(*this, ST, Allow2D);
371366
unsigned Last = (NumElements / Width - 1) * VStride + (Width - 1) * Stride;
372367
unsigned Max = InputNumElements - Last - 1 + RealIdx;
373368
unsigned Min = RealIdx;
@@ -378,11 +373,11 @@ unsigned Region::getLegalSize(unsigned Idx, bool Allow2D,
378373
else if (MinMaxGRFDiff == 1 && GRFsPerIndirect > 1)
379374
ElementsToBoundary = ElementsPerGRF - (Max & (ElementsPerGRF - 1));
380375
// We may be able to refine an indirect region legal width further...
381-
if (exactLog2(ParentWidth) >= 0
382-
&& ParentWidth <= ElementsPerGRF) {
383-
// ParentWidth tells us that a row of our region cannot cross a GRF
384-
// boundary. Say that the boundary is at the next multiple of
385-
// ParentWidth.
376+
if (exactLog2(ParentWidth) >= 0 &&
377+
ParentWidth <= GRFsPerIndirect * ElementsPerGRF) {
378+
// ParentWidth tells us that a row of our region cannot cross a
379+
// possible number of elements addressed by indirect region. Say that
380+
// the boundary is at the next multiple of ParentWidth.
386381
ElementsToBoundary = std::max(ParentWidth - RealIdx % ParentWidth,
387382
ElementsToBoundary);
388383
} else if (!isa<VectorType>(Indirect->getType())) {

IGC/VectorCompiler/lib/GenXCodeGen/GenXRegionCollapsing.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ class GenXRegionCollapsing : public FunctionPass {
113113
Value *insertOp(Instruction::BinaryOps Opcode, Value *Lhs, Value *Rhs,
114114
const Twine &Name, Instruction *InsertBefore,
115115
const DebugLoc &DL);
116+
bool isSingleElementRdRExtract(Instruction *I);
116117
};
117118

118119
}// end namespace llvm
@@ -579,6 +580,21 @@ void GenXRegionCollapsing::processRdRegion(Instruction *InnerRd)
579580
LLVM_DEBUG(dbgs() << "Cannot normalize element type\n");
580581
return;
581582
}
583+
584+
// If it's a signle element extract from an indirect region
585+
// then check if there exist some other extracts
586+
if (OuterR.Indirect && (OuterR.NumElements != 1) &&
587+
isSingleElementRdRExtract(InnerRd)) {
588+
auto NumExtracts = llvm::count_if(OuterRd->uses(), [this](Use &U) {
589+
return isSingleElementRdRExtract(cast<Instruction>(U.getUser()));
590+
});
591+
// If there are some more extracts except this one (InnerRd)
592+
// then not combine these regions to prevent generation
593+
// of extra address conversions for a combined region
594+
if (NumExtracts > 1)
595+
return;
596+
}
597+
582598
Region CombinedR;
583599
if (!combineRegions(&OuterR, &InnerR, &CombinedR))
584600
return; // cannot combine
@@ -1463,3 +1479,9 @@ Value *GenXRegionCollapsing::insertOp(Instruction::BinaryOps Opcode, Value *Lhs,
14631479
return Inst;
14641480
}
14651481

1482+
bool GenXRegionCollapsing::isSingleElementRdRExtract(Instruction *I) {
1483+
if (!GenXIntrinsic::isRdRegion(I))
1484+
return false;
1485+
Region R = Region::getWithOffset(I, /*WantParentWidth=*/true);
1486+
return R.NumElements == 1 && !R.Indirect;
1487+
}

IGC/VectorCompiler/lib/GenXCodeGen/GenXUtil.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
4949
#include "llvm/IR/Metadata.h"
5050
#include "llvm/IR/Module.h"
5151

52+
#include "llvmWrapper/IR/InstrTypes.h"
53+
5254
#include "Probe/Assertion.h"
5355
#include <iterator>
5456

@@ -1805,3 +1807,16 @@ bool genx::breakConstantExprs(Function *F) {
18051807
}
18061808
return Modified;
18071809
}
1810+
1811+
unsigned genx::getNumGRFsPerIndirectForRegion(const genx::Region &R,
1812+
const GenXSubtarget *ST,
1813+
bool Allow2D) {
1814+
IGC_ASSERT(R.Indirect && "Indirect region expected");
1815+
IGC_ASSERT(ST);
1816+
if (ST->hasIndirectGRFCrossing() &&
1817+
// SKL+. See if we can allow GRF crossing.
1818+
(Allow2D || !R.is2D())) {
1819+
return 2;
1820+
}
1821+
return 1;
1822+
}

IGC/VectorCompiler/lib/GenXCodeGen/GenXUtil.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,9 @@ bool breakConstantExprs(Instruction *I);
524524
// breakConstantExprs : break constant expressions in function F.
525525
// Return true if any modifications have been made, false otherwise.
526526
bool breakConstantExprs(Function *F);
527+
// Get possible number of GRFs for indirect region
528+
unsigned getNumGRFsPerIndirectForRegion(const genx::Region &R,
529+
const GenXSubtarget *ST, bool Allow2D);
527530

528531
// BinaryDataAccumulator: it's a helper class to accumulate binary data
529532
// in one buffer.

0 commit comments

Comments
 (0)