@@ -282,7 +282,7 @@ StatusPrivArr2Reg LowerGEPForPrivMem::CheckIfAllocaPromotable(llvm::AllocaInst *
282
282
283
283
allowedAllocaSizeInBytes = (allowedAllocaSizeInBytes * 8 ) / SIMDSize;
284
284
}
285
- SOALayoutChecker checker (*pAlloca, m_ctx->type == ShaderType::OPENCL_SHADER);
285
+ SOALayoutChecker checker (*pAlloca, m_ctx->type == ShaderType::OPENCL_SHADER, true );
286
286
SOALayoutInfo SOAInfo = checker.getOrGatherInfo ();
287
287
if (!SOAInfo.canUseSOALayout ) {
288
288
return StatusPrivArr2Reg::CannotUseSOALayout;
@@ -357,7 +357,7 @@ StatusPrivArr2Reg LowerGEPForPrivMem::CheckIfAllocaPromotable(llvm::AllocaInst *
357
357
return StatusPrivArr2Reg::OK;
358
358
}
359
359
360
- SOALayoutChecker::SOALayoutChecker (AllocaInst &allocaToCheck, bool isOCL) : allocaRef(allocaToCheck) {
360
+ SOALayoutChecker::SOALayoutChecker (AllocaInst &allocaToCheck, bool isOCL, bool mismatchedWidthsSupport ) : allocaRef(allocaToCheck), mismatchedWidthsSupport(mismatchedWidthsSupport ) {
361
361
auto F = allocaToCheck.getParent ()->getParent ();
362
362
pDL = &F->getParent ()->getDataLayout ();
363
363
newAlgoControl = IGC_GET_FLAG_VALUE (EnablePrivMemNewSOATranspose);
@@ -571,9 +571,12 @@ bool IGC::SOALayoutChecker::MismatchDetected(Instruction &I) {
571
571
return false ;
572
572
573
573
Type *allocaTy = allocaRef.getAllocatedType ();
574
- bool allocaIsVecOrArr = allocaTy->isVectorTy () || allocaTy->isArrayTy ();
574
+ bool allocaIsVecOrArrOrStruct = allocaTy->isVectorTy () || allocaTy->isArrayTy () || allocaTy-> isStructTy ();
575
575
576
- if (!allocaIsVecOrArr)
576
+ if (!allocaIsVecOrArrOrStruct)
577
+ return false ;
578
+
579
+ if (mismatchedWidthsSupport)
577
580
return false ;
578
581
579
582
bool useOldAlgorithm = !useNewAlgo (pInfo->baseType );
@@ -593,12 +596,24 @@ bool IGC::SOALayoutChecker::MismatchDetected(Instruction &I) {
593
596
allocaTy = arrTy->getElementType ();
594
597
} else if (auto *vec = dyn_cast<IGCLLVM::FixedVectorType>(allocaTy)) {
595
598
allocaTy = vec->getElementType ();
599
+ } else if (auto *strct = dyn_cast<StructType>(allocaTy)){
600
+ if (auto *arrTy = dyn_cast<ArrayType>(strct->getStructElementType (0 ))) {
601
+ allocaTy = arrTy->getElementType ();
602
+ } else if (auto *vec = dyn_cast<IGCLLVM::FixedVectorType>(strct->getStructElementType (0 ))){
603
+ allocaTy = vec->getElementType ();
604
+ }
596
605
}
597
606
598
607
if (auto *arrTy = dyn_cast<ArrayType>(pUserTy)) {
599
608
pUserTy = arrTy->getElementType ();
600
609
} else if (auto *vec = dyn_cast<IGCLLVM::FixedVectorType>(pUserTy)) {
601
610
pUserTy = vec->getElementType ();
611
+ } else if (auto *strct = dyn_cast<StructType>(pUserTy)){
612
+ if (auto *arrTy = dyn_cast<ArrayType>(strct->getStructElementType (0 ))) {
613
+ pUserTy = arrTy->getElementType ();
614
+ } else if (auto *vec = dyn_cast<IGCLLVM::FixedVectorType>(strct->getStructElementType (0 ))){
615
+ pUserTy = vec->getElementType ();
616
+ }
602
617
}
603
618
}
604
619
0 commit comments