@@ -651,85 +651,77 @@ static bool expandTypedBufferStoreIntrinsic(CallInst *Orig) {
651651 assert ((IsDouble || ScalarTy->isIntegerTy (64 )) &&
652652 " Only expand double or int64 scalars or vectors" );
653653
654- unsigned ExtractNum = 2 ;
655- if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) {
656- assert (VT->getNumElements () == 2 &&
654+ // Determine if we're dealing with a vector or scalar
655+ bool IsVector = isa<FixedVectorType>(BufferTy);
656+ if (IsVector) {
657+ assert (cast<FixedVectorType>(BufferTy)->getNumElements () == 2 &&
657658 " TypedBufferStore vector must be size 2" );
658- ExtractNum = 4 ;
659659 }
660+
661+ // Create the appropriate vector type for the result
662+ Type *Int32Ty = Builder.getInt32Ty ();
663+ Type *ResultTy = VectorType::get (Int32Ty, IsVector ? 4 : 2 , false );
664+ Value *Val = PoisonValue::get (ResultTy);
665+
666+ // Split the 64-bit values into 32-bit components
660667 if (IsDouble) {
661- Type *SplitElementTy = Builder.getInt32Ty ();
662- if (ExtractNum == 4 )
668+ // Handle double type(s)
669+ Type *SplitElementTy = Int32Ty;
670+ if (IsVector)
663671 SplitElementTy = VectorType::get (SplitElementTy, 2 , false );
664672
665- // Handle double type(s) - keep original behavior
666673 auto *SplitTy = llvm::StructType::get (SplitElementTy, SplitElementTy);
667674 Value *Split = Builder.CreateIntrinsic (SplitTy, Intrinsic::dx_splitdouble,
668675 {Orig->getOperand (2 )});
669- // create our vector
670676 Value *LowBits = Builder.CreateExtractValue (Split, 0 );
671677 Value *HighBits = Builder.CreateExtractValue (Split, 1 );
672- Value *Val;
673- if (ExtractNum == 2 ) {
674- Val = PoisonValue::get (VectorType::get (Builder.getInt32Ty (), 2 , false ));
678+
679+ if (IsVector) {
680+ // For vector doubles, use shuffle to create the final vector
681+ Val = Builder.CreateShuffleVector (LowBits, HighBits, {0 , 2 , 1 , 3 });
682+ } else {
683+ // For scalar doubles, insert the elements
675684 Val = Builder.CreateInsertElement (Val, LowBits, Builder.getInt32 (0 ));
676685 Val = Builder.CreateInsertElement (Val, HighBits, Builder.getInt32 (1 ));
677- } else
678- Val = Builder.CreateShuffleVector (LowBits, HighBits, {0 , 2 , 1 , 3 });
679-
680- Builder.CreateIntrinsic (Builder.getVoidTy (),
681- Intrinsic::dx_resource_store_typedbuffer,
682- {Orig->getOperand (0 ), Orig->getOperand (1 ), Val});
686+ }
683687 } else {
684688 // Handle int64 type(s)
685689 Value *InputVal = Orig->getOperand (2 );
686- Value *Val;
687690
688- if (ExtractNum == 4 ) {
691+ if (IsVector ) {
689692 // Handle vector of int64
690- Type *Int32x4Ty = VectorType::get (Builder.getInt32Ty (), 4 , false );
691- Val = PoisonValue::get (Int32x4Ty);
692-
693693 for (unsigned I = 0 ; I < 2 ; ++I) {
694694 // Extract each int64 element
695695 Value *Int64Val =
696696 Builder.CreateExtractElement (InputVal, Builder.getInt32 (I));
697697
698- // Get low 32 bits by truncating to i32
699- Value *LowBits = Builder.CreateTrunc (Int64Val, Builder.getInt32Ty ());
700-
701- // Get high 32 bits by shifting right by 32 and truncating
698+ // Split into low and high 32-bit parts
699+ Value *LowBits = Builder.CreateTrunc (Int64Val, Int32Ty);
702700 Value *ShiftedVal = Builder.CreateLShr (Int64Val, Builder.getInt64 (32 ));
703- Value *HighBits = Builder.CreateTrunc (ShiftedVal, Builder. getInt32Ty () );
701+ Value *HighBits = Builder.CreateTrunc (ShiftedVal, Int32Ty );
704702
705- // Insert into our final vector
703+ // Insert into result vector
706704 Val =
707705 Builder.CreateInsertElement (Val, LowBits, Builder.getInt32 (I * 2 ));
708706 Val = Builder.CreateInsertElement (Val, HighBits,
709707 Builder.getInt32 (I * 2 + 1 ));
710708 }
711709 } else {
712710 // Handle scalar int64
713- Type *Int32x2Ty = VectorType::get (Builder.getInt32Ty (), 2 , false );
714- Val = PoisonValue::get (Int32x2Ty);
715-
716- // Get low 32 bits by truncating to i32
717- Value *LowBits = Builder.CreateTrunc (InputVal, Builder.getInt32Ty ());
718-
719- // Get high 32 bits by shifting right by 32 and truncating
711+ Value *LowBits = Builder.CreateTrunc (InputVal, Int32Ty);
720712 Value *ShiftedVal = Builder.CreateLShr (InputVal, Builder.getInt64 (32 ));
721- Value *HighBits = Builder.CreateTrunc (ShiftedVal, Builder. getInt32Ty () );
713+ Value *HighBits = Builder.CreateTrunc (ShiftedVal, Int32Ty );
722714
723- // Insert into our final vector
724715 Val = Builder.CreateInsertElement (Val, LowBits, Builder.getInt32 (0 ));
725716 Val = Builder.CreateInsertElement (Val, HighBits, Builder.getInt32 (1 ));
726717 }
727-
728- Builder.CreateIntrinsic (Builder.getVoidTy (),
729- Intrinsic::dx_resource_store_typedbuffer,
730- {Orig->getOperand (0 ), Orig->getOperand (1 ), Val});
731718 }
732719
720+ // Create the final intrinsic call
721+ Builder.CreateIntrinsic (Builder.getVoidTy (),
722+ Intrinsic::dx_resource_store_typedbuffer,
723+ {Orig->getOperand (0 ), Orig->getOperand (1 ), Val});
724+
733725 Orig->eraseFromParent ();
734726 return true ;
735727}
0 commit comments