Skip to content

Commit ee57323

Browse files
bcheng0127igcbot
authored andcommitted
SWSB: track WAR A0 register dependence
Current solution is too conservative, which may affect performance because setting A@1.
1 parent 93ce5eb commit ee57323

File tree

3 files changed

+82
-50
lines changed

3 files changed

+82
-50
lines changed

visa/HWCaps.inc

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -464,9 +464,13 @@ uint32_t getNumAddrRegisters() const {
464464
return 16;
465465
}
466466

467-
uint32_t getGRFNumOfAddrRegisters() const {
468-
// The element size of address register is 16 bits
469-
return ((getNumAddrRegisters() * 2 - 1) / numEltPerGRF<Type_UB>() + 1);
467+
// each address register is 16 bits
468+
uint32_t getNumAddrRegistersInGRFSizeSWSB() const {
469+
if (hasThreeALUPipes() || hasFourALUPipes()) {
470+
return ((16 * G4_WSIZE) + numEltPerGRF<Type_UB>() - 1) / numEltPerGRF<Type_UB>();
471+
} else {
472+
return 0;
473+
}
470474
}
471475

472476
uint32_t getNumScalarRegisters(void) {
@@ -744,7 +748,7 @@ bool hasWriteCombine() const {
744748
}
745749

746750
bool hasA0WARHWissue() {
747-
return getPlatform() >= Xe_XeHPSDV;
751+
return getPlatform() >= Xe_XeHPSDV && getPlatform() < Xe2;
748752
}
749753

750754
bool hasFtoPackedHFMove() const { return getPlatform() >= Xe_DG2; }
@@ -896,8 +900,8 @@ bool supports4GRFAlign() const {
896900
return false;
897901
}
898902

899-
bool needA0WARForSend() const {
900-
return false;
903+
bool needA0WAR() const {
904+
return (getPlatform() >= Xe2);
901905
}
902906

903907
bool alwaysAllowGlobalFlagOpt() const {

visa/LocalScheduler/SWSB_G4IR.cpp

Lines changed: 67 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,8 @@ void SBNode::finalizeDistanceType1(IR_Builder &builder,
569569
return;
570570
}
571571

572-
if (builder.hasA0WARHWissue() && (builder.hasThreeALUPipes() || builder.hasFourALUPipes())) {
572+
if (builder.hasA0WARHWissue() &&
573+
(builder.hasThreeALUPipes() || builder.hasFourALUPipes())) {
573574
G4_INST *inst = GetInstruction();
574575

575576
if (inst->getDst() && inst->getDst()->isDirectA0()) {
@@ -712,7 +713,8 @@ void SBNode::finalizeDistanceType2(IR_Builder &builder,
712713
return;
713714
}
714715

715-
if (builder.hasA0WARHWissue() && (builder.hasThreeALUPipes() || builder.hasFourALUPipes())) {
716+
if (builder.hasA0WARHWissue() &&
717+
(builder.hasThreeALUPipes() || builder.hasFourALUPipes())) {
716718
G4_INST *inst = GetInstruction();
717719

718720
if (inst->getDst() && inst->getDst()->isDirectA0()) {
@@ -870,7 +872,8 @@ void SBNode::finalizeDistanceType3(IR_Builder &builder,
870872
return;
871873
}
872874

873-
if (builder.hasA0WARHWissue() && (builder.hasThreeALUPipes() || builder.hasFourALUPipes())) {
875+
if (builder.hasA0WARHWissue() &&
876+
(builder.hasThreeALUPipes() || builder.hasFourALUPipes())) {
874877
G4_INST *inst = GetInstruction();
875878

876879
if (inst->getDst() && inst->getDst()->isDirectA0()) {
@@ -1258,18 +1261,33 @@ SBFootprint *G4_BB_SB::getFootprintForFlag(G4_Operand *opnd,
12581261
SBFootprint *G4_BB_SB::getFootprintForA0(G4_Operand *opnd,
12591262
Gen4_Operand_Number opnd_num,
12601263
G4_INST *inst) {
1264+
bool valid = true;
1265+
unsigned subRegNum = 0;
1266+
if (opnd->isSrcRegRegion()) {
1267+
G4_SrcRegRegion *srcRegRegion = opnd->asSrcRegRegion();
1268+
if (srcRegRegion->getRegAccess() == Direct) {
1269+
subRegNum = srcRegRegion->ExSubRegNum(valid);
1270+
} else {
1271+
subRegNum = srcRegRegion->ExIndSubRegNum(valid);
1272+
}
1273+
} else if (opnd->isDstRegRegion()) {
1274+
G4_DstRegRegion *dstRegRegion = opnd->asDstRegRegion();
1275+
if (dstRegRegion->getRegAccess() == Direct) {
1276+
subRegNum = dstRegRegion->ExSubRegNum(valid);
1277+
} else {
1278+
subRegNum = dstRegRegion->ExIndSubRegNum(valid);
1279+
}
1280+
} else {
1281+
vISA_ASSERT_UNREACHABLE("invalid A0 operand");
1282+
}
1283+
12611284
unsigned short LB = 0;
12621285
unsigned short RB = 0;
12631286
G4_Type type = opnd->getType();
1287+
G4_Type addrType = opnd->isIndirect() ? ADDR_REG_TYPE : opnd->getType();
1288+
LB = subRegNum * TypeSize(addrType);
1289+
RB = opnd->getRightBound() - opnd->getLeftBound() + LB;
12641290

1265-
bool valid = true;
1266-
unsigned subRegOff = opnd->getBase()->ExSubRegNum(valid);
1267-
G4_Type addrType = opnd->isIndirect() ? Type_UW : opnd->getType();
1268-
1269-
LB = (unsigned short)(subRegOff * TypeSize(addrType));
1270-
RB = (unsigned short)(LB + opnd->getRightBound() - opnd->getLeftBound());
1271-
1272-
// Updated to the bucket footprint
12731291
LB += (builder.kernel.getNumRegTotal() + builder.getNumScalarRegisters()) *
12741292
builder.numEltPerGRF<Type_UB>();
12751293
RB += (builder.kernel.getNumRegTotal() + builder.getNumScalarRegisters()) *
@@ -4694,8 +4712,7 @@ void SWSB::insertTokenSync() {
46944712
syncInst->setDistance(1);
46954713
if (kernel.fg.builder->hasThreeALUPipes() ||
46964714
kernel.fg.builder->hasFourALUPipes()) {
4697-
syncInst->setDistanceTypeXe(
4698-
G4_INST::DistanceType::DISTALL);
4715+
syncInst->setDistanceTypeXe(G4_INST::DistanceType::DISTALL);
46994716
}
47004717
}
47014718
}
@@ -5729,9 +5746,15 @@ bool G4_BB_SB::getFootprintForOperand(SBNode *node, G4_INST *inst,
57295746
}
57305747
}
57315748

5732-
if (builder.needA0WARForSend() && isA0) {
5749+
if (builder.needA0WAR() && isA0) {
57335750
footprint = getFootprintForA0(opnd, opndNum, inst);
5734-
node->setFootprint(footprint, opndNum);
5751+
if (opndNum == Opnd_dst && opnd->asDstRegRegion()->isIndirect()) {
5752+
// Indirect will only be used in the src0~src2, using Opnd_src4 as the
5753+
// indirect used in dst
5754+
node->setFootprint(footprint, Opnd_src4);
5755+
} else {
5756+
node->setFootprint(footprint, opndNum);
5757+
}
57355758
}
57365759

57375760
if (isS0Reg) {
@@ -6169,21 +6192,14 @@ void G4_BB_SB::setDistance(const SBFootprint *footprint, SBNode *node,
61696192
}
61706193

61716194
void G4_BB_SB::setSpecialDistance(SBNode *node) {
6172-
G4_INST *inst = node->GetInstruction();
6173-
if (!inst->getDst()) {
6174-
return;
6175-
}
6176-
6177-
if (inst->getDst()->isDirectA0()) {
6178-
SBDISTDEP_ITEM depItem;
6179-
depItem.liveNodePipe = PIPE_FLOAT;
6180-
depItem.nodePipe = node->ALUPipe;
6181-
depItem.operandType = PIPE_INT;
6182-
depItem.dstDep = false;
6183-
node->setDistance(1);
6184-
node->distDep.push_back(depItem);
6185-
node->setDistInfo(PIPE_FLOAT, 1);
6186-
}
6195+
SBDISTDEP_ITEM depItem;
6196+
depItem.liveNodePipe = PIPE_FLOAT;
6197+
depItem.nodePipe = node->ALUPipe;
6198+
depItem.operandType = PIPE_INT;
6199+
depItem.dstDep = false;
6200+
node->setDistance(1);
6201+
node->distDep.push_back(depItem);
6202+
node->setDistInfo(PIPE_FLOAT, 1);
61876203

61886204
return;
61896205
}
@@ -6675,8 +6691,17 @@ void G4_BB_SB::SBDDD(G4_BB *bb, LiveGRFBuckets *&LB,
66756691

66766692
if (builder.hasA0WARHWissue() &&
66776693
(builder.hasThreeALUPipes() || builder.hasFourALUPipes())) {
6678-
setSpecialDistance(node);
6694+
if (curInst->getDst() && curInst->getDst()->isDirectA0()) {
6695+
setSpecialDistance(node);
6696+
}
6697+
} else if (builder.needA0WAR()) {
6698+
if (!indexes->setFirstA0 && curInst->getDst() &&
6699+
curInst->getDst()->isDirectA0()) {
6700+
indexes->setFirstA0 = 1;
6701+
setSpecialDistance(node);
6702+
}
66796703
}
6704+
66806705
// Record the node IDs of the instructions in BB
66816706
if (first_node == INVALID_ID) {
66826707
first_node = nodeID;
@@ -7085,13 +7110,16 @@ void G4_BB_SB::SBDDD(G4_BB *bb, LiveGRFBuckets *&LB,
70857110
continue;
70867111
}
70877112

7088-
if (builder.needA0WARForSend() &&
7089-
curBucket == builder.kernel.getNumRegTotal() +
7090-
builder.getNumScalarRegisters()) {
7091-
if (!tokenHonourInstruction(liveInst) || dep != WAR ||
7092-
hasSameFunctionID(liveInst, curInst)) {
7093-
++bn_it;
7094-
continue;
7113+
if (builder.needA0WAR()) {
7114+
const int A0_start =
7115+
builder.kernel.getNumRegTotal() + builder.getNumScalarRegisters();
7116+
const int A0_end =
7117+
A0_start + builder.getNumAddrRegistersInGRFSizeSWSB() - 1;
7118+
if (curBucket >= A0_start && curBucket <= A0_end) {
7119+
if (dep != WAR) {
7120+
++bn_it;
7121+
continue;
7122+
}
70957123
}
70967124
}
70977125

@@ -7204,7 +7232,7 @@ void G4_BB_SB::SBDDD(G4_BB *bb, LiveGRFBuckets *&LB,
72047232
if (distanceHonourInstruction(liveInst)) {
72057233
if (dep == RAW &&
72067234
(curBucket < globalRegisterNum)) { // Only need track GRF
7207-
// RAW dependence
7235+
// RAW dependence
72087236
LB->killOperand(bn_it);
72097237
setDistance(curFootprint, node, liveNode, false);
72107238
liveNode->setInstKilled(true); // Instrtuction level kill

visa/LocalScheduler/SWSB_G4IR.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,7 @@ typedef std::list<G4_BB_SB *> BB_SWSB_LIST;
609609
typedef BB_SWSB_LIST::iterator BB_SWSB_LIST_ITER;
610610

611611
typedef struct _SWSB_INDEXES {
612+
int setFirstA0 = 0;
612613
int instIndex = 0;
613614
int ALUIndex = 0;
614615
int integerIndex = 0;
@@ -682,7 +683,6 @@ class G4_BB_SB {
682683

683684
int send_start = -1;
684685
int send_end = -1;
685-
686686
unsigned loopStartBBID = -1; // The start BB ID of live range
687687
unsigned loopEndBBID = -1; // The start BB ID of live range
688688

@@ -733,8 +733,8 @@ class G4_BB_SB {
733733
last_send_node = -1;
734734
totalGRFNum = builder.kernel.getNumRegTotal();
735735
globalRegisterNum = totalGRFNum + builder.getNumScalarRegisters();
736-
if (builder.needA0WARForSend()) {
737-
globalRegisterNum += builder.getGRFNumOfAddrRegisters();
736+
if (builder.needA0WAR()) {
737+
globalRegisterNum += builder.getNumAddrRegistersInGRFSizeSWSB();
738738
}
739739

740740
SBDDD(bb, lb, globalLB, GRFAlignedGlobalSendsLB, SBNodes, SBSendNodes,
@@ -1175,8 +1175,8 @@ class SWSB {
11751175
{
11761176
globalRegisterNum =
11771177
kernel.getNumRegTotal() + k.fg.builder->getNumScalarRegisters();
1178-
if (k.fg.builder->needA0WARForSend()) {
1179-
globalRegisterNum += k.fg.builder->getGRFNumOfAddrRegisters();
1178+
if (k.fg.builder->needA0WAR()) {
1179+
globalRegisterNum += k.fg.builder->getNumAddrRegistersInGRFSizeSWSB();
11801180
}
11811181

11821182
indexes.instIndex = 0;

0 commit comments

Comments
 (0)