Skip to content

Commit bf11541

Browse files
authored
Adapt to upstream (#2228)
1 parent c4f953f commit bf11541

File tree

2 files changed

+11
-14
lines changed

2 files changed

+11
-14
lines changed

enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,12 @@ static mlir::Type batchType(mlir::Type type, int64_t width) {
4141
return RankedTensorType::get({width}, type);
4242
}
4343

44-
class FloatTypeInterface
45-
: public AutoDiffTypeInterface::ExternalModel<FloatTypeInterface,
46-
FloatType> {
44+
template <typename ConcreteType>
45+
class FloatTypeInterface : public AutoDiffTypeInterface::ExternalModel<
46+
FloatTypeInterface<ConcreteType>, ConcreteType> {
4747
public:
4848
Value createNullValue(Type self, OpBuilder &builder, Location loc) const {
49-
auto fltType = self.cast<FloatType>();
49+
auto fltType = self.cast<ConcreteType>();
5050
return builder.create<arith::ConstantFloatOp>(
5151
loc, APFloat(fltType.getFloatSemantics(), 0), fltType);
5252
}
@@ -200,10 +200,10 @@ class ComplexTypeInterface
200200
void mlir::enzyme::registerBuiltinDialectAutoDiffInterface(
201201
DialectRegistry &registry) {
202202
registry.addExtension(+[](MLIRContext *context, BuiltinDialect *) {
203-
BFloat16Type::attachInterface<FloatTypeInterface>(*context);
204-
Float16Type::attachInterface<FloatTypeInterface>(*context);
205-
Float32Type::attachInterface<FloatTypeInterface>(*context);
206-
Float64Type::attachInterface<FloatTypeInterface>(*context);
203+
BFloat16Type::attachInterface<FloatTypeInterface<BFloat16Type>>(*context);
204+
Float16Type::attachInterface<FloatTypeInterface<Float16Type>>(*context);
205+
Float32Type::attachInterface<FloatTypeInterface<Float32Type>>(*context);
206+
Float64Type::attachInterface<FloatTypeInterface<Float64Type>>(*context);
207207
IntegerType::attachInterface<IntegerTypeInterface<IntegerType>>(*context);
208208
IndexType::attachInterface<IntegerTypeInterface<IndexType>>(*context);
209209
UnrankedTensorType::attachInterface<TensorTypeInterface>(*context);

enzyme/Enzyme/MustExitScalarEvolution.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -340,12 +340,9 @@ ScalarEvolution::ExitLimit MustExitScalarEvolution::computeExitLimitFromICmp(
340340
const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsExit,
341341
bool AllowPredicates) {
342342
// If the condition was exit on true, convert the condition to exit on false
343-
ICmpInst::Predicate Pred;
344-
if (!ExitIfTrue)
345-
Pred = ExitCond->getPredicate();
346-
else
347-
Pred = ExitCond->getInversePredicate();
348-
const ICmpInst::Predicate OriginalPred = Pred;
343+
auto Pred = (!ExitIfTrue) ? ExitCond->getPredicate()
344+
: ExitCond->getInversePredicate();
345+
const auto OriginalPred = Pred;
349346

350347
#if LLVM_VERSION_MAJOR < 14
351348
// Handle common loops like: for (X = "string"; *X; ++X)

0 commit comments

Comments
 (0)