2525#include " llvm/IR/PassManager.h"
2626#include " llvm/IR/Type.h"
2727#include " llvm/Pass.h"
28+ #include " llvm/Support/Casting.h"
2829#include " llvm/Support/ErrorHandling.h"
2930#include " llvm/Support/MathExtras.h"
3031
@@ -70,15 +71,17 @@ static bool isIntrinsicExpansion(Function &F) {
7071 case Intrinsic::vector_reduce_add:
7172 case Intrinsic::vector_reduce_fadd:
7273 return true ;
73- case Intrinsic::dx_resource_load_typedbuffer:
74- // We need to handle doubles and vector of doubles.
75- return F.getReturnType ()
76- ->getStructElementType (0 )
77- ->getScalarType ()
78- ->isDoubleTy ();
79- case Intrinsic::dx_resource_store_typedbuffer:
80- // We need to handle doubles and vector of doubles.
81- return F.getFunctionType ()->getParamType (2 )->getScalarType ()->isDoubleTy ();
74+ case Intrinsic::dx_resource_load_typedbuffer: {
75+ // We need to handle i64, doubles, and vectors of them.
76+ Type *ScalarTy =
77+ F.getReturnType ()->getStructElementType (0 )->getScalarType ();
78+ return ScalarTy->isDoubleTy () || ScalarTy->isIntegerTy (64 );
79+ }
80+ case Intrinsic::dx_resource_store_typedbuffer: {
81+ // We need to handle i64 and doubles and vectors of i64 and doubles.
82+ Type *ScalarTy = F.getFunctionType ()->getParamType (2 )->getScalarType ();
83+ return ScalarTy->isDoubleTy () || ScalarTy->isIntegerTy (64 );
84+ }
8285 }
8386 return false ;
8487}
@@ -545,13 +548,15 @@ static bool expandTypedBufferLoadIntrinsic(CallInst *Orig) {
545548 IRBuilder<> Builder (Orig);
546549
547550 Type *BufferTy = Orig->getType ()->getStructElementType (0 );
548- assert (BufferTy->getScalarType ()->isDoubleTy () &&
549- " Only expand double or double2" );
551+ Type *ScalarTy = BufferTy->getScalarType ();
552+ bool IsDouble = ScalarTy->isDoubleTy ();
553+ assert (IsDouble || ScalarTy->isIntegerTy (64 ) &&
554+ " Only expand double or int64 scalars or vectors" );
550555
551556 unsigned ExtractNum = 2 ;
552557 if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) {
553558 assert (VT->getNumElements () == 2 &&
554- " TypedBufferLoad double vector has wrong size" );
559+ " TypedBufferLoad vector must be size 2 " );
555560 ExtractNum = 4 ;
556561 }
557562
@@ -570,22 +575,54 @@ static bool expandTypedBufferLoadIntrinsic(CallInst *Orig) {
570575 ExtractElements.push_back (
571576 Builder.CreateExtractElement (Extract, Builder.getInt32 (I)));
572577
573- // combine into double(s)
578+ // combine into double(s) or int64(s)
574579 Value *Result = PoisonValue::get (BufferTy);
575580 for (unsigned I = 0 ; I < ExtractNum; I += 2 ) {
576- Value *Dbl =
577- Builder.CreateIntrinsic (Builder.getDoubleTy (), Intrinsic::dx_asdouble,
578- {ExtractElements[I], ExtractElements[I + 1 ]});
581+ Value *Combined = nullptr ;
582+ if (IsDouble) {
583+ // For doubles, use dx_asdouble intrinsic
584+ Combined =
585+ Builder.CreateIntrinsic (Builder.getDoubleTy (), Intrinsic::dx_asdouble,
586+ {ExtractElements[I], ExtractElements[I + 1 ]});
587+ } else {
588+ // For int64, manually combine two int32s
589+ // First, zero-extend both values to i64
590+ Value *Lo = Builder.CreateZExt (ExtractElements[I], Builder.getInt64Ty ());
591+ Value *Hi =
592+ Builder.CreateZExt (ExtractElements[I + 1 ], Builder.getInt64Ty ());
593+ // Shift the high bits left by 32 bits
594+ Value *ShiftedHi = Builder.CreateShl (Hi, Builder.getInt64 (32 ));
595+ // OR the high and low bits together
596+ Combined = Builder.CreateOr (Lo, ShiftedHi);
597+ }
598+
579599 if (ExtractNum == 4 )
580- Result =
581- Builder. CreateInsertElement (Result, Dbl, Builder.getInt32 (I / 2 ));
600+ Result = Builder. CreateInsertElement (Result, Combined,
601+ Builder.getInt32 (I / 2 ));
582602 else
583- Result = Dbl ;
603+ Result = Combined ;
584604 }
585605
586606 Value *CheckBit = nullptr ;
587607 for (User *U : make_early_inc_range (Orig->users ())) {
588- auto *EVI = cast<ExtractValueInst>(U);
608+ if (auto *Ret = dyn_cast<ReturnInst>(U)) {
609+ // For return instructions, we need to handle the case where the function
610+ // is directly returning the result of the call
611+ Type *RetTy = Ret->getFunction ()->getReturnType ();
612+ Value *StructRet = PoisonValue::get (RetTy);
613+ StructRet = Builder.CreateInsertValue (StructRet, Result, {0 });
614+ Value *CheckBitForRet = Builder.CreateExtractValue (Load, {1 });
615+ StructRet = Builder.CreateInsertValue (StructRet, CheckBitForRet, {1 });
616+ Ret->setOperand (0 , StructRet);
617+ continue ;
618+ }
619+ auto *EVI = dyn_cast<ExtractValueInst>(U);
620+ if (!EVI) {
621+ // If it's not a ReturnInst or ExtractValueInst, we don't know how to
622+ // handle it
623+ llvm_unreachable (" Unexpected user of typedbufferload" );
624+ }
625+
589626 ArrayRef<unsigned > Indices = EVI->getIndices ();
590627 assert (Indices.size () == 1 );
591628
@@ -609,38 +646,90 @@ static bool expandTypedBufferStoreIntrinsic(CallInst *Orig) {
609646 IRBuilder<> Builder (Orig);
610647
611648 Type *BufferTy = Orig->getFunctionType ()->getParamType (2 );
612- assert (BufferTy->getScalarType ()->isDoubleTy () &&
613- " Only expand double or double2" );
649+ Type *ScalarTy = BufferTy->getScalarType ();
650+ bool IsDouble = ScalarTy->isDoubleTy ();
651+ assert ((IsDouble || ScalarTy->isIntegerTy (64 )) &&
652+ " Only expand double or int64 scalars or vectors" );
614653
615654 unsigned ExtractNum = 2 ;
616655 if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) {
617656 assert (VT->getNumElements () == 2 &&
618- " TypedBufferStore double vector has wrong size" );
657+ " TypedBufferStore vector must be size 2 " );
619658 ExtractNum = 4 ;
620659 }
660+ if (IsDouble) {
661+ Type *SplitElementTy = Builder.getInt32Ty ();
662+ if (ExtractNum == 4 )
663+ SplitElementTy = VectorType::get (SplitElementTy, 2 , false );
664+
665+ // Handle double type(s) - keep original behavior
666+ auto *SplitTy = llvm::StructType::get (SplitElementTy, SplitElementTy);
667+ Value *Split = Builder.CreateIntrinsic (SplitTy, Intrinsic::dx_splitdouble,
668+ {Orig->getOperand (2 )});
669+ // create our vector
670+ Value *LowBits = Builder.CreateExtractValue (Split, 0 );
671+ Value *HighBits = Builder.CreateExtractValue (Split, 1 );
672+ Value *Val;
673+ if (ExtractNum == 2 ) {
674+ Val = PoisonValue::get (VectorType::get (Builder.getInt32Ty (), 2 , false ));
675+ Val = Builder.CreateInsertElement (Val, LowBits, Builder.getInt32 (0 ));
676+ Val = Builder.CreateInsertElement (Val, HighBits, Builder.getInt32 (1 ));
677+ } else
678+ Val = Builder.CreateShuffleVector (LowBits, HighBits, {0 , 2 , 1 , 3 });
679+
680+ Builder.CreateIntrinsic (Builder.getVoidTy (),
681+ Intrinsic::dx_resource_store_typedbuffer,
682+ {Orig->getOperand (0 ), Orig->getOperand (1 ), Val});
683+ } else {
684+ // Handle int64 type(s)
685+ Value *InputVal = Orig->getOperand (2 );
686+ Value *Val;
687+
688+ if (ExtractNum == 4 ) {
689+ // Handle vector of int64
690+ Type *Int32x4Ty = VectorType::get (Builder.getInt32Ty (), 4 , false );
691+ Val = PoisonValue::get (Int32x4Ty);
692+
693+ for (unsigned I = 0 ; I < 2 ; ++I) {
694+ // Extract each int64 element
695+ Value *Int64Val =
696+ Builder.CreateExtractElement (InputVal, Builder.getInt32 (I));
697+
698+ // Get low 32 bits by truncating to i32
699+ Value *LowBits = Builder.CreateTrunc (Int64Val, Builder.getInt32Ty ());
700+
701+ // Get high 32 bits by shifting right by 32 and truncating
702+ Value *ShiftedVal = Builder.CreateLShr (Int64Val, Builder.getInt64 (32 ));
703+ Value *HighBits = Builder.CreateTrunc (ShiftedVal, Builder.getInt32Ty ());
704+
705+ // Insert into our final vector
706+ Val =
707+ Builder.CreateInsertElement (Val, LowBits, Builder.getInt32 (I * 2 ));
708+ Val = Builder.CreateInsertElement (Val, HighBits,
709+ Builder.getInt32 (I * 2 + 1 ));
710+ }
711+ } else {
712+ // Handle scalar int64
713+ Type *Int32x2Ty = VectorType::get (Builder.getInt32Ty (), 2 , false );
714+ Val = PoisonValue::get (Int32x2Ty);
715+
716+ // Get low 32 bits by truncating to i32
717+ Value *LowBits = Builder.CreateTrunc (InputVal, Builder.getInt32Ty ());
718+
719+ // Get high 32 bits by shifting right by 32 and truncating
720+ Value *ShiftedVal = Builder.CreateLShr (InputVal, Builder.getInt64 (32 ));
721+ Value *HighBits = Builder.CreateTrunc (ShiftedVal, Builder.getInt32Ty ());
722+
723+ // Insert into our final vector
724+ Val = Builder.CreateInsertElement (Val, LowBits, Builder.getInt32 (0 ));
725+ Val = Builder.CreateInsertElement (Val, HighBits, Builder.getInt32 (1 ));
726+ }
727+
728+ Builder.CreateIntrinsic (Builder.getVoidTy (),
729+ Intrinsic::dx_resource_store_typedbuffer,
730+ {Orig->getOperand (0 ), Orig->getOperand (1 ), Val});
731+ }
621732
622- Type *SplitElementTy = Builder.getInt32Ty ();
623- if (ExtractNum == 4 )
624- SplitElementTy = VectorType::get (SplitElementTy, 2 , false );
625-
626- // split our double(s)
627- auto *SplitTy = llvm::StructType::get (SplitElementTy, SplitElementTy);
628- Value *Split = Builder.CreateIntrinsic (SplitTy, Intrinsic::dx_splitdouble,
629- Orig->getOperand (2 ));
630- // create our vector
631- Value *LowBits = Builder.CreateExtractValue (Split, 0 );
632- Value *HighBits = Builder.CreateExtractValue (Split, 1 );
633- Value *Val;
634- if (ExtractNum == 2 ) {
635- Val = PoisonValue::get (VectorType::get (SplitElementTy, 2 , false ));
636- Val = Builder.CreateInsertElement (Val, LowBits, Builder.getInt32 (0 ));
637- Val = Builder.CreateInsertElement (Val, HighBits, Builder.getInt32 (1 ));
638- } else
639- Val = Builder.CreateShuffleVector (LowBits, HighBits, {0 , 2 , 1 , 3 });
640-
641- Builder.CreateIntrinsic (Builder.getVoidTy (),
642- Intrinsic::dx_resource_store_typedbuffer,
643- {Orig->getOperand (0 ), Orig->getOperand (1 ), Val});
644733 Orig->eraseFromParent ();
645734 return true ;
646735}
0 commit comments