@@ -446,16 +446,23 @@ static Type *getResolvedVectorElementType(Type *matrixType) {
446446 return ty->getElementType ();
447447}
448448
449- static int getSliceSize (const JointMatrixTypeDescription *desc) {
449+ static int getSliceSize (const JointMatrixTypeDescription *desc, Type *matTy) {
450+ IGCLLVM::FixedVectorType *ty = dyn_cast<IGCLLVM::FixedVectorType>(matTy);
451+ IGC_ASSERT_MESSAGE (ty, " Expecting vector type in calculating slice size" );
452+
453+ IntegerType *vecElemType = dyn_cast<IntegerType>(ty->getElementType ());
454+ IGC_ASSERT_MESSAGE (vecElemType, " Expecting integer type for vector element." );
455+
456+ unsigned contribTypeWidth = vecElemType->getBitWidth ();
450457 if (desc->layout == LayoutRowMajor) {
451458 return desc->rows ;
452459 }
453460 if (desc->bitWidth != 0 ) {
454461 if (desc->layout == LayoutPackedA) {
455- return desc->rows * (32 / desc->bitWidth );
462+ return desc->rows * (contribTypeWidth / desc->bitWidth );
456463 }
457464 if (desc->layout == LayoutPackedB) {
458- return 8 * (32 / desc->bitWidth );
465+ return 8 * (contribTypeWidth / desc->bitWidth );
459466 }
460467 }
461468 IGC_ASSERT_MESSAGE (true , " Unexpected matrix layout." );
@@ -511,7 +518,7 @@ Value *JointMatrixFuncsResolutionPass::ResolveFill(CallInst *CI) {
511518 Type *matTy = ResolveType (CI->getType (), &desc);
512519
513520 IRBuilder builder (CI);
514- const int sliceSize = getSliceSize (&desc);
521+ const int sliceSize = getSliceSize (&desc, matTy );
515522 const int vectorSize = getResolvedVectorSize (matTy);
516523 /* Case with packing: */
517524 if (sliceSize > vectorSize) {
@@ -523,6 +530,14 @@ Value *JointMatrixFuncsResolutionPass::ResolveFill(CallInst *CI) {
523530 IGC_ASSERT_MESSAGE (false , " Malformed matrix slice." );
524531 }
525532
533+ if (fillValue->getType ()->isPointerTy ())
534+ {
535+ IntegerType *vectorElementType = dyn_cast<IntegerType>(getResolvedVectorElementType (matTy));
536+ PointerType *PT = dyn_cast<PointerType>(fillValue->getType ());
537+ fillValue = builder.CreateBitCast (fillValue, PointerType::get (vectorElementType, PT->getAddressSpace ()));
538+ fillValue = builder.CreateLoad (vectorElementType, fillValue);
539+ }
540+
526541 Value *slice = UndefValue::get (matTy);
527542 for (int i = 0 ; i < vectorSize; i++) {
528543 slice = builder.CreateInsertElement (slice, fillValue, i);
@@ -534,9 +549,9 @@ Value *JointMatrixFuncsResolutionPass::ResolveFill(CallInst *CI) {
534549
535550Value *JointMatrixFuncsResolutionPass::ResolveWILength (CallInst *CI) {
536551 JointMatrixTypeDescription desc;
537- ResolveType (CI->getArgOperand (0 )->getType (), &desc);
552+ Type *matTy = ResolveType (CI->getArgOperand (0 )->getType (), &desc);
538553
539- const int sliceSize = getSliceSize (&desc);
554+ const int sliceSize = getSliceSize (&desc, matTy );
540555 Value *lenght = ConstantInt::get (CI->getType (), sliceSize, " matrix.slice.size" );
541556
542557 CI->replaceAllUsesWith (lenght);
@@ -546,8 +561,8 @@ Value *JointMatrixFuncsResolutionPass::ResolveWILength(CallInst *CI) {
546561
547562template <class BuilderT >
548563static Value *createSliceExtract
549- (BuilderT *builder, Value *matrix, Value *index, const JointMatrixTypeDescription *desc) {
550- const int sliceSize = getSliceSize (desc);
564+ (BuilderT *builder, Value *matrix, Value *index, const JointMatrixTypeDescription *desc, Type *matTy ) {
565+ const int sliceSize = getSliceSize (desc, matTy );
551566 const int vectorSize = getResolvedVectorSize (matrix->getType ());
552567 /* Unpacking: */
553568 if (sliceSize > vectorSize) {
@@ -568,12 +583,12 @@ Value *JointMatrixFuncsResolutionPass::ResolveSliceInsert(CallInst *CI) {
568583 IGCLLVM::FixedVectorType *matTy = dyn_cast<IGCLLVM::FixedVectorType>(rawMatTy);
569584
570585 IRBuilder builder (CI);
571- const int sliceSize = getSliceSize (&desc);
586+ const int sliceSize = getSliceSize (&desc, rawMatTy );
572587 const int vectorSize = getResolvedVectorSize (matTy);
573588
574589 Value *slice = nullptr ;
575590 if (sliceSize > vectorSize) {
576- Value *element = createSliceExtract (&builder, matrix, index, &desc);
591+ Value *element = createSliceExtract (&builder, matrix, index, &desc, rawMatTy );
577592 if (!isa<IntegerType>(element->getType ())) {
578593 unsigned vecElemSize = matTy->getElementType ()->getScalarSizeInBits ();
579594 element = builder.CreateBitCast (element, Type::getIntNTy (builder.getContext (), vecElemSize));
@@ -603,6 +618,10 @@ Value *JointMatrixFuncsResolutionPass::ResolveSliceInsert(CallInst *CI) {
603618 component = builder.CreateShl (component, offset);
604619 component = builder.CreateOr (element, component);
605620 }
621+
622+ IntegerType *vectorElementType = dyn_cast<IntegerType>(getResolvedVectorElementType (rawMatTy));
623+ component = builder.CreateBitCast (component, vectorElementType);
624+
606625 slice = builder.CreateInsertElement (matrix, component, index);
607626
608627 InstsToErase.insert (CI);
@@ -617,9 +636,9 @@ Value *JointMatrixFuncsResolutionPass::ResolveSliceExtract(CallInst *CI) {
617636 Type *matTy = ResolveType (CI->getArgOperand (0 )->getType (), &desc);
618637
619638 IRBuilder builder (CI);
620- Value *element = createSliceExtract (&builder, matrix, index, &desc);
639+ Value *element = createSliceExtract (&builder, matrix, index, &desc, matTy );
621640 /* Unpacking: */
622- const int sliceSize = getSliceSize (&desc);
641+ const int sliceSize = getSliceSize (&desc, matTy );
623642 const int vectorSize = getResolvedVectorSize (matTy);
624643 if (sliceSize > vectorSize) {
625644 index = builder.CreateTruncOrBitCast (index, element->getType ());
@@ -634,6 +653,10 @@ Value *JointMatrixFuncsResolutionPass::ResolveSliceExtract(CallInst *CI) {
634653 element = builder.CreateBitCast (element, CI->getType ());
635654 }
636655
656+ // We need the bitcast, especially for half, as the function call that is
657+ // being replaces has a half return type and the vectorElementType is i16
658+ element = builder.CreateBitCast (element, CI->getType ());
659+
637660 CI->replaceAllUsesWith (element);
638661 InstsToErase.insert (CI);
639662 return element;
0 commit comments