Skip to content

Commit 657ca67

Browse files
authored
Support int calling for cublas (#2685)
* Support int calling for cublas * fix cublasv2
1 parent 06ccc95 commit 657ca67

File tree

4 files changed

+33
-6
lines changed

4 files changed

+33
-6
lines changed

enzyme/Enzyme/ActivityAnalysis.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1362,16 +1362,24 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
13621362
<< " dependant on " << *LoadReval << "\n";
13631363
ReEvaluateValueIfInactiveInst[LoadReval].insert(Val);
13641364
}
1365-
if (StoreReval && EnzymeEnableRecursiveHypotheses)
1365+
if (StoreReval && EnzymeEnableRecursiveHypotheses) {
1366+
if (EnzymePrintActivity)
1367+
llvm::errs() << " global activity of " << *Val
1368+
<< " dependant on " << *StoreReval << "\n";
13661369
ReEvaluateValueIfInactiveInst[StoreReval].insert(Val);
1370+
}
13671371
if (ValLoadReval && EnzymeEnableRecursiveHypotheses) {
13681372
if (EnzymePrintActivity)
13691373
llvm::errs() << " global activity of " << *Val
13701374
<< " dependant on " << *ValLoadReval << "\n";
13711375
ReEvaluateValueIfInactiveValue[ValLoadReval].insert(Val);
13721376
}
1373-
if (ValStoreReval && EnzymeEnableRecursiveHypotheses)
1377+
if (ValStoreReval && EnzymeEnableRecursiveHypotheses) {
1378+
if (EnzymePrintActivity)
1379+
llvm::errs() << " global activity of " << *Val
1380+
<< " dependant on " << *ValStoreReval << "\n";
13741381
ReEvaluateValueIfInactiveValue[ValStoreReval].insert(Val);
1382+
}
13751383
}
13761384
}
13771385
}

enzyme/Enzyme/DiffeGradientUtils.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -818,14 +818,18 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig,
818818

819819
bool needsCast = false;
820820
#if LLVM_VERSION_MAJOR < 17
821-
if (origptr->getContext().supportsTypedPointers()) {
821+
if (isa<PointerType>(origptr->getType()) &&
822+
origptr->getContext().supportsTypedPointers()) {
822823
needsCast = origptr->getType()->getPointerElementType() != addingType;
823824
}
824825
#endif
825826

826827
assert(ptr);
827-
if (start != 0 || needsCast) {
828+
if (start != 0 || needsCast || !isa<PointerType>(origptr->getType())) {
828829
auto rule = [&](Value *ptr) {
830+
if (!isa<PointerType>(origptr->getType())) {
831+
ptr = BuilderM.CreateIntToPtr(ptr, getUnqual(addingType));
832+
}
829833
if (start != 0) {
830834
auto i8 = Type::getInt8Ty(ptr->getContext());
831835
ptr = BuilderM.CreatePointerCast(
@@ -845,7 +849,9 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig,
845849
ptr = applyChainRule(
846850
PointerType::get(
847851
addingType,
848-
cast<PointerType>(origptr->getType())->getAddressSpace()),
852+
isa<PointerType>(origptr->getType())
853+
? cast<PointerType>(origptr->getType())->getAddressSpace()
854+
: 0),
849855
BuilderM, rule, ptr);
850856
}
851857

enzyme/tools/enzyme-tblgen/blas-tblgen.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1473,7 +1473,13 @@ void rev_call_args(bool forward, Twine argName, const TGPattern &pattern,
14731473
// Distinguish later trough byRef if it is cblas (thus has layout)
14741474
os << " if (cblas) " << argName << ".push_back(arg_layout);\n";
14751475
}
1476-
os << " if (cublas) " << argName << ".push_back(arg_handle);\n";
1476+
os << " if (cublas) " << argName << ".push_back(";
1477+
if (!forward)
1478+
os << "lookup(";
1479+
os << "arg_handle";
1480+
if (!forward)
1481+
os << ", Builder2)";
1482+
os << ");\n";
14771483

14781484
for (size_t pos = fncHasLayout ? 1 : 0; pos < numArgs; pos++) {
14791485
os << " for (auto item : ";

enzyme/tools/enzyme-tblgen/blasDeclUpdater.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,13 @@ inline void emit_attributeBLAS(const TGPattern &pattern, raw_ostream &os) {
117117
"llvm::Attribute::get(F->getContext(), llvm::Attribute::ZExt));\n";
118118
}
119119
os << " }\n";
120+
121+
if (has_active_return(name)) {
122+
os << "const bool cublasv2 = blas.prefix == "
123+
"\"cublas\" && llvm::StringRef(blas.suffix).contains(\"v2\");\n";
124+
os << " if (cublasv2) argTys.push_back(getUnqual(fpType));\n";
125+
}
126+
120127
os << " auto nextFT = llvm::FunctionType::get(prevFT->getReturnType(), "
121128
"argTys, false);\n";
122129
os << " if (nextFT != prevFT && F->empty()) {\n";

0 commit comments

Comments
 (0)