@@ -1633,7 +1633,7 @@ class CBSIntrinsic {
16331633 return Builder.CreateSelect (EnoughIterationsLeft, Builder.getInt64 (SGSize), InnerDimIterationsLeft);
16341634 }();
16351635
1636- res = vectorizeUniformValue (Storage, Builder, Intrinsic, Builder. CreateTruncOrBitCast ( SGIterations, Storage-> getType ()) );
1636+ res = vectorizeUniformValue (Storage, Builder, Intrinsic, SGIterations);
16371637 } else {
16381638 auto *VType = llvm::VectorType::get (Op0->getType (), llvm::ElementCount::getFixed (SGSize));
16391639
@@ -1748,29 +1748,21 @@ class CBSIntrinsic {
17481748class ReduceIntrinsic final : public CBSIntrinsic {
17491749 std::string_view getName () override { return " __cbs_reduce" ; }
17501750
1751+ std::string getTypeStr (llvm::Type* Type) {
1752+ std::string type_str;
1753+ llvm::raw_string_ostream rso (type_str);
1754+ Type->print (rso);
1755+ return rso.str ();
1756+ }
1757+
17511758 std::pair<llvm::Value *, Shape> vectorizeUniformValue (llvm::Value *Storage,
17521759 llvm::IRBuilder<> &Builder,
17531760 llvm::CallInst &Intrinsic, llvm::Value* NumberOfLoopIterationsLeft) override {
17541761 auto *Idx = llvm::dyn_cast<llvm::ConstantInt>(Intrinsic.getOperand (1 ));
1762+ auto *Type = Storage->getType ();
17551763 assert (Idx and " Op must be constant int" );
17561764 const auto v = Idx->getSExtValue ();
1757- // ADD
1758- if (v == 0 ) {
1759- return {Builder.CreateMul (Storage, NumberOfLoopIterationsLeft),
1760- Shape::UNIFORM};
1761- }
1762- // MUL
1763- if (v == 1 ) {
1764- // POW(, 32)
1765-
1766- // TODO
1767- assert (false );
1768- llvm::Value *result = Storage;
1769- for (auto i = 1ul ; i < SGSize; ++i) {
1770- result = Builder.CreateMul (result, Storage);
1771- }
1772- return {result, Shape::UNIFORM};
1773- }
1765+ const bool isInt = Type->isIntegerTy ();
17741766 // min
17751767 if (v == 2 ) {
17761768 return {Storage, Shape::UNIFORM};
@@ -1779,6 +1771,37 @@ class ReduceIntrinsic final : public CBSIntrinsic {
17791771 if (v == 3 ) {
17801772 return {Storage, Shape::UNIFORM};
17811773 }
1774+
1775+ if (v == 0 ) {
1776+ // ADD
1777+ if (isInt)
1778+ return {Builder.CreateMul (Storage,
1779+ Builder.CreateIntCast (NumberOfLoopIterationsLeft, Type, false )),
1780+ Shape::UNIFORM};
1781+ return {Builder.CreateFMul (
1782+ Storage, Builder.CreateUIToFP (NumberOfLoopIterationsLeft, Storage->getType ())),
1783+ Shape::UNIFORM};
1784+ }
1785+ if (v == 1 ) {
1786+ auto M = Intrinsic.getParent ()->getParent ()->getParent ();
1787+ if (not isInt) {
1788+ auto *Pow =
1789+ llvm::Intrinsic::getDeclaration (M, llvm::Intrinsic::powi, {Type, Builder.getInt32Ty ()});
1790+ llvm::Value *result = Storage;
1791+ llvm::SmallVector<llvm::Value *> Args{
1792+ result, Builder.CreateIntCast (NumberOfLoopIterationsLeft, Builder.getInt32Ty (), false )};
1793+ result = Builder.CreateCall (Pow, Args);
1794+ return {result, Shape::UNIFORM};
1795+ }
1796+ // LLVM provides no integer pow intrinsic (llvm.powi only accepts a floating-point base), so call a generated helper instead
1797+ auto *Pow = createPowFunction (M, Type);
1798+ llvm::Value *result = Storage;
1799+ llvm::SmallVector<llvm::Value *> Args{
1800+ result, Builder.CreateIntCast (NumberOfLoopIterationsLeft, Builder.getInt32Ty (), false )};
1801+ result = Builder.CreateCall (Pow, Args);
1802+ return {result, Shape::UNIFORM};
1803+ }
1804+
17821805 assert (false );
17831806 return {};
17841807 }
@@ -1852,7 +1875,59 @@ class ReduceIntrinsic final : public CBSIntrinsic {
18521875 return {};
18531876 return llvm::ConstantInt::get (Type, 0 );
18541877 }
1855- };
1878+
1879+ llvm::Function* createPowFunction (llvm::Module* module , llvm::Type* Type) {
1880+ llvm::LLVMContext& context = module ->getContext ();
1881+ llvm::IRBuilder<> builder (context);
1882+
1883+ llvm::FunctionType* funcType = llvm::FunctionType::get (Type, {Type, builder.getInt32Ty ()}, false );
1884+ auto powFunction = llvm::dyn_cast<llvm::Function>(module ->getOrInsertFunction (" pow." + getTypeStr (Type), funcType).getCallee ());
1885+
1886+ // Create a basic block and set the insert point
1887+ llvm::BasicBlock* entry = llvm::BasicBlock::Create (context, " entry" , powFunction);
1888+ builder.SetInsertPoint (entry);
1889+
1890+ // Get function arguments
1891+ auto args = powFunction->arg_begin ();
1892+ llvm::Value* base = args++;
1893+ base->setName (" base" );
1894+ llvm::Value* exponent = args++;
1895+ exponent->setName (" exponent" );
1896+
1897+ // Initialize loop variables
1898+ llvm::AllocaInst* result = builder.CreateAlloca (Type, nullptr , " result" );
1899+ llvm::Value* counter = builder.CreateAlloca (exponent->getType (), nullptr , " counter" );
1900+ builder.CreateStore (builder.getIntN (exponent->getType ()->getIntegerBitWidth (), 0 ), counter);
1901+ builder.CreateStore (builder.getIntN (Type->getIntegerBitWidth (), 1 ), result);
1902+
1903+ // Create loop blocks
1904+ llvm::BasicBlock* loopBB = llvm::BasicBlock::Create (context, " loop" , powFunction);
1905+ llvm::BasicBlock* afterLoopBB = llvm::BasicBlock::Create (context, " afterloop" , powFunction);
1906+
1907+ // Branch to loop block
1908+ builder.CreateBr (loopBB);
1909+ builder.SetInsertPoint (loopBB);
1910+
1911+ // Load counter value
1912+ llvm::Value* counterValue = builder.CreateLoad (exponent->getType (), counter, " counterValue" );
1913+
1914+ // Loop body
1915+ llvm::Value* nextCounter = builder.CreateAdd (counterValue, builder.getIntN (exponent->getType ()->getIntegerBitWidth (), 1 ), " nextcounter" );
1916+ builder.CreateStore (nextCounter, counter);
1917+ auto resultX = builder.CreateMul (builder.CreateLoad (result->getAllocatedType (), result), base, " result" );
1918+ builder.CreateStore (resultX, result);
1919+
1920+ llvm::Value* cond = builder.CreateICmpULT (nextCounter, exponent, " loopcond" );
1921+ builder.CreateCondBr (cond, loopBB, afterLoopBB);
1922+
1923+ builder.SetInsertPoint (afterLoopBB);
1924+
1925+ builder.CreateRet (builder.CreateLoad (result->getAllocatedType (), result));
1926+
1927+ return powFunction;
1928+ }
1929+
1930+ };
18561931
18571932template <bool Left> class Shift final : public CBSIntrinsic {
18581933 std::string_view getName () override { return Left ? " __cbs_shift_left" : " __cbs_shift_right" ; }
0 commit comments