Skip to content

Commit ba00913

Browse files
fix types after llvm/llvm-project#123026
1 parent a6ae057 commit ba00913

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
549549
}
550550
if (isa<AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp>(op)) {
551551
MLIRContext *context = op->getContext();
552-
Type floatDtype = mlir::FloatType::getF64(context);
552+
Type floatDtype = mlir::Float64Type::get(context);
553553
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype);
554554
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], floatDtype);
555555
Value zero =
@@ -569,7 +569,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
569569
}
570570
if (isa<AtenLogicalNotOp>(op)) {
571571
MLIRContext *context = op->getContext();
572-
Type floatDtype = mlir::FloatType::getF64(context);
572+
Type floatDtype = mlir::Float64Type::get(context);
573573
Value self = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype);
574574
Value zero =
575575
b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
@@ -1028,7 +1028,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
10281028
Type powType = dtype;
10291029
if (payloadArgs[0].getType().isInteger() ||
10301030
payloadArgs[1].getType().isInteger())
1031-
powType = mlir::FloatType::getF64(op->getContext());
1031+
powType = mlir::Float64Type::get(op->getContext());
10321032
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], powType);
10331033
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], powType);
10341034
auto powOp = b.create<math::PowFOp>(loc, lhs, rhs);

lib/Dialect/Torch/Utils/Utils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,9 @@ Torch::getTypeForScalarType(MLIRContext *context,
152152
case torch_upstream::ScalarType::Bool:
153153
return IntegerType::get(context, 1);
154154
case torch_upstream::ScalarType::BFloat16:
155-
return mlir::FloatType::getBF16(context);
155+
return mlir::BFloat16Type::get(context);
156156
case torch_upstream::ScalarType::Half:
157-
return mlir::FloatType::getF16(context);
157+
return mlir::Float16Type::get(context);
158158
case torch_upstream::ScalarType::Byte:
159159
return mlir::IntegerType::get(context, 8, mlir::IntegerType::Unsigned);
160160
case torch_upstream::ScalarType::Char:

0 commit comments

Comments
 (0)