Skip to content

Commit 5c410e3

Browse files
author
Moritz
committed
Implemented todo
1 parent c1bf668 commit 5c410e3

File tree

1 file changed

+94
-19
lines changed

1 file changed

+94
-19
lines changed

src/compiler/cbs/SubCfgFormation.cpp

Lines changed: 94 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1633,7 +1633,7 @@ class CBSIntrinsic {
16331633
return Builder.CreateSelect(EnoughIterationsLeft, Builder.getInt64(SGSize), InnerDimIterationsLeft);
16341634
}();
16351635

1636-
res = vectorizeUniformValue(Storage, Builder, Intrinsic, Builder.CreateTruncOrBitCast(SGIterations, Storage->getType()));
1636+
res = vectorizeUniformValue(Storage, Builder, Intrinsic, SGIterations);
16371637
} else {
16381638
auto *VType = llvm::VectorType::get(Op0->getType(), llvm::ElementCount::getFixed(SGSize));
16391639

@@ -1748,29 +1748,21 @@ class CBSIntrinsic {
17481748
class ReduceIntrinsic final : public CBSIntrinsic {
17491749
std::string_view getName() override { return "__cbs_reduce"; }
17501750

1751+
/// Renders \p Type through its own `print` method and returns the resulting text.
/// The result is used as a suffix when naming per-type helper functions.
std::string getTypeStr(llvm::Type* Type) {
  std::string Rendered;
  llvm::raw_string_ostream Stream{Rendered};
  Type->print(Stream);
  // Push any buffered output into `Rendered` before it is handed back.
  Stream.flush();
  return Rendered;
}
1757+
17511758
std::pair<llvm::Value *, Shape> vectorizeUniformValue(llvm::Value *Storage,
17521759
llvm::IRBuilder<> &Builder,
17531760
llvm::CallInst &Intrinsic, llvm::Value* NumberOfLoopIterationsLeft) override {
17541761
auto *Idx = llvm::dyn_cast<llvm::ConstantInt>(Intrinsic.getOperand(1));
1762+
auto *Type = Storage->getType();
17551763
assert(Idx and "Op must be constant int");
17561764
const auto v = Idx->getSExtValue();
1757-
// ADD
1758-
if (v == 0) {
1759-
return {Builder.CreateMul(Storage, NumberOfLoopIterationsLeft),
1760-
Shape::UNIFORM};
1761-
}
1762-
// MUL
1763-
if (v == 1) {
1764-
// POW(, 32)
1765-
1766-
// TODO
1767-
assert(false);
1768-
llvm::Value *result = Storage;
1769-
for (auto i = 1ul; i < SGSize; ++i) {
1770-
result = Builder.CreateMul(result, Storage);
1771-
}
1772-
return {result, Shape::UNIFORM};
1773-
}
1765+
const bool isInt = Type->isIntegerTy();
17741766
// min
17751767
if (v == 2) {
17761768
return {Storage, Shape::UNIFORM};
@@ -1779,6 +1771,37 @@ class ReduceIntrinsic final : public CBSIntrinsic {
17791771
if (v == 3) {
17801772
return {Storage, Shape::UNIFORM};
17811773
}
1774+
1775+
if (v == 0) {
1776+
// ADD
1777+
if (isInt)
1778+
return {Builder.CreateMul(Storage,
1779+
Builder.CreateIntCast(NumberOfLoopIterationsLeft, Type, false)),
1780+
Shape::UNIFORM};
1781+
return {Builder.CreateFMul(
1782+
Storage, Builder.CreateUIToFP(NumberOfLoopIterationsLeft, Storage->getType())),
1783+
Shape::UNIFORM};
1784+
}
1785+
if (v == 1) {
1786+
auto M = Intrinsic.getParent()->getParent()->getParent();
1787+
if (not isInt) {
1788+
auto *Pow =
1789+
llvm::Intrinsic::getDeclaration(M, llvm::Intrinsic::powi, {Type, Builder.getInt32Ty()});
1790+
llvm::Value *result = Storage;
1791+
llvm::SmallVector<llvm::Value *> Args{
1792+
result, Builder.CreateIntCast(NumberOfLoopIterationsLeft, Builder.getInt32Ty(), false)};
1793+
result = Builder.CreateCall(Pow, Args);
1794+
return {result, Shape::UNIFORM};
1795+
}
1796+
// WTF LLVM does not have integer pow intrinsic only floating point
1797+
auto *Pow = createPowFunction(M, Type);
1798+
llvm::Value *result = Storage;
1799+
llvm::SmallVector<llvm::Value *> Args{
1800+
result, Builder.CreateIntCast(NumberOfLoopIterationsLeft, Builder.getInt32Ty(), false)};
1801+
result = Builder.CreateCall(Pow, Args);
1802+
return {result, Shape::UNIFORM};
1803+
}
1804+
17821805
assert(false);
17831806
return {};
17841807
}
@@ -1852,7 +1875,59 @@ class ReduceIntrinsic final : public CBSIntrinsic {
18521875
return {};
18531876
return llvm::ConstantInt::get(Type, 0);
18541877
}
1855-
};
1878+
1879+
/// Returns the integer power helper `pow.<type>(base, exponent)` for \p Type in
/// \p module, emitting its body the first time it is requested.
///
/// The generated function starts an accumulator at 1 and multiplies it by
/// `base` once per iteration while an unsigned counter is below `exponent`,
/// so an exponent of 0 yields 1 (the previous do-while shaped loop returned
/// `base` in that case).
///
/// \param module Module that receives (or already contains) the helper.
/// \param Type   Type of the base and of the result; must be an integer type.
/// \return The helper function with signature `Type (Type, i32)`.
llvm::Function* createPowFunction(llvm::Module* module, llvm::Type* Type) {
  assert(Type->isIntegerTy() and "integer pow helper requires an integer type");

  llvm::LLVMContext& context = module->getContext();
  llvm::IRBuilder<> builder(context);

  llvm::FunctionType* funcType = llvm::FunctionType::get(Type, {Type, builder.getInt32Ty()}, false);
  auto* powFunction = llvm::dyn_cast<llvm::Function>(
      module->getOrInsertFunction("pow." + getTypeStr(Type), funcType).getCallee());
  assert(powFunction and "pow helper name is already taken by an incompatible symbol");

  // The helper is shared by every reduce intrinsic of the same type. If an
  // earlier call already emitted the body, reuse it instead of appending a
  // second set of basic blocks to the same function.
  if (not powFunction->empty())
    return powFunction;

  // The body is generated per module under a fixed name; keep it module-local
  // so several modules defining it do not clash at link time.
  powFunction->setLinkage(llvm::GlobalValue::InternalLinkage);

  // Name the arguments.
  auto args = powFunction->arg_begin();
  llvm::Value* base = args++;
  base->setName("base");
  llvm::Value* exponent = args++;
  exponent->setName("exponent");

  llvm::BasicBlock* entry = llvm::BasicBlock::Create(context, "entry", powFunction);
  llvm::BasicBlock* loopBB = llvm::BasicBlock::Create(context, "loop", powFunction);
  llvm::BasicBlock* bodyBB = llvm::BasicBlock::Create(context, "loopbody", powFunction);
  llvm::BasicBlock* afterLoopBB = llvm::BasicBlock::Create(context, "afterloop", powFunction);

  builder.SetInsertPoint(entry);
  builder.CreateBr(loopBB);

  // Loop header: the condition is tested before the first multiplication so
  // that a zero exponent skips the body entirely.
  builder.SetInsertPoint(loopBB);
  llvm::PHINode* counter = builder.CreatePHI(exponent->getType(), 2, "counter");
  llvm::PHINode* result = builder.CreatePHI(Type, 2, "result");
  counter->addIncoming(llvm::ConstantInt::get(exponent->getType(), 0), entry);
  result->addIncoming(llvm::ConstantInt::get(Type, 1), entry);
  llvm::Value* cond = builder.CreateICmpULT(counter, exponent, "loopcond");
  builder.CreateCondBr(cond, bodyBB, afterLoopBB);

  // Loop body: one multiplication and one counter increment per iteration.
  builder.SetInsertPoint(bodyBB);
  llvm::Value* nextResult = builder.CreateMul(result, base, "nextresult");
  llvm::Value* nextCounter =
      builder.CreateAdd(counter, llvm::ConstantInt::get(exponent->getType(), 1), "nextcounter");
  builder.CreateBr(loopBB);
  counter->addIncoming(nextCounter, bodyBB);
  result->addIncoming(nextResult, bodyBB);

  builder.SetInsertPoint(afterLoopBB);
  builder.CreateRet(result);

  return powFunction;
}
1929+
1930+
};
18561931

18571932
template <bool Left> class Shift final : public CBSIntrinsic {
18581933
std::string_view getName() override { return Left ? "__cbs_shift_left" : "__cbs_shift_right"; }

0 commit comments

Comments
 (0)