Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 96 additions & 53 deletions lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -911,9 +911,15 @@ OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) {
if (inputType == resType && inputType.hasDtype())
return getOperand(0);

// Fold conversion of splat values.
auto elems = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
if (!elems || !elems.isSplat())
DenseElementsAttr elems;
if (auto valueTensorLiteralOp =
getOperand(0).getDefiningOp<ValueTensorLiteralOp>()) {
elems =
dyn_cast_or_null<DenseElementsAttr>(valueTensorLiteralOp.getValue());
} else {
elems = dyn_cast_or_null<DenseElementsAttr>(adaptor.getSelf());
}
if (!elems || (!elems.isSplat() && elems.size() > 10))
return {};

auto outVTy = dyn_cast<ValueTensorType>(getType());
Expand All @@ -927,64 +933,101 @@ OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) {
Type srcEltTy = inputType.getDtype();
Type dstEltTy = outVTy.getDtype();

// Handle integer destination.
if (auto dstI = dyn_cast<IntegerType>(dstEltTy)) {
// any -> bool(i1).
if (dstI.isSignlessInteger(1)) {
bool truthy = false;
if (isa<mlir::FloatType>(srcEltTy)) {
const APFloat &floatVal = elems.getSplatValue<APFloat>();
truthy = !floatVal.isZero();
} else {
const APInt &intVal = elems.getSplatValue<APInt>();
truthy = !intVal.isZero();
auto convertElement = [&](Attribute srcAttr) -> std::optional<Attribute> {
// Handle integer destination.
if (auto dstI = dyn_cast<IntegerType>(dstEltTy)) {
// any -> bool(i1).
if (dstI.isSignlessInteger(1)) {
bool truthy = false;
if (isa<mlir::FloatType>(srcEltTy)) {
const APFloat &floatVal = cast<FloatAttr>(srcAttr).getValue();
truthy = !floatVal.isZero();
} else {
const APInt &intVal = cast<IntegerAttr>(srcAttr).getValue();
truthy = !intVal.isZero();
}
return IntegerAttr::get(dstEltTy, APInt(/*numBits=*/1, truthy));
}
return DenseElementsAttr::get(outShaped, APInt(/*numBits=*/1, truthy));
// float -> intN
if (auto srcF = dyn_cast<mlir::FloatType>(srcEltTy)) {
APSInt result(dstI.getWidth(), /*isUnsigned=*/dstI.isUnsignedInteger());
bool isExact = false;
APFloat f = cast<FloatAttr>(srcAttr).getValue();
APFloat::opStatus st =
f.convertToInteger(result, APFloat::rmTowardZero, &isExact);
if (st == APFloat::opOK || st == APFloat::opInexact)
return IntegerAttr::get(dstEltTy, APInt(result));
return {}; // NaN/Inf/out-of-range: preserve runtime semantics.
}
// intM -> intN
const APInt v = cast<IntegerAttr>(srcAttr).getValue();
auto isUnsigned = cast<IntegerType>(srcEltTy).isUnsignedInteger();
auto isSignless = cast<IntegerType>(srcEltTy).isSignlessInteger();
APInt casted = isUnsigned || isSignless ? v.zextOrTrunc(dstI.getWidth())
: v.sextOrTrunc(dstI.getWidth());
return IntegerAttr::get(dstEltTy, casted);
}
// float -> intN
if (auto srcF = dyn_cast<mlir::FloatType>(srcEltTy)) {
APSInt result(dstI.getWidth(), /*isUnsigned=*/dstI.isUnsignedInteger());
bool isExact = false;
APFloat f = elems.getSplatValue<APFloat>();

// Handle float destination.
if (auto dstF = dyn_cast<mlir::FloatType>(dstEltTy)) {
const llvm::fltSemantics &dstSem = dstF.getFloatSemantics();

// int -> float
if (auto srcI = dyn_cast<IntegerType>(srcEltTy)) {
APFloat f(dstSem);
APFloat::opStatus st = f.convertFromAPInt(
cast<IntegerAttr>(srcAttr).getValue(),
/*isSigned=*/!srcI.isUnsignedInteger() && !srcI.isSignlessInteger(),
APFloat::rmNearestTiesToEven);
if (st == APFloat::opOK || st == APFloat::opInexact)
return FloatAttr::get(dstF, f);
return {};
}

// floatX -> floatY
APFloat f = cast<FloatAttr>(srcAttr).getValue();
bool losesInfo = false;
APFloat::opStatus st =
f.convertToInteger(result, APFloat::rmTowardZero, &isExact);
if (st == APFloat::opOK || st == APFloat::opInexact)
return DenseElementsAttr::get(outShaped, APInt(result));
return {}; // NaN/Inf/out-of-range: preserve runtime semantics.
}
// intM -> intN
const APInt &v = elems.getSplatValue<APInt>();
auto isUnsigned = cast<IntegerType>(srcEltTy).isUnsignedInteger();
auto isSignless = cast<IntegerType>(srcEltTy).isSignlessInteger();
APInt casted = isUnsigned || isSignless ? v.zextOrTrunc(dstI.getWidth())
: v.sextOrTrunc(dstI.getWidth());
return DenseElementsAttr::get(outShaped, casted);
}

// Handle float destination.
if (auto dstF = dyn_cast<mlir::FloatType>(dstEltTy)) {
const llvm::fltSemantics &dstSem = dstF.getFloatSemantics();

// int -> float
if (auto srcI = dyn_cast<IntegerType>(srcEltTy)) {
APFloat f(dstSem);
APFloat::opStatus st = f.convertFromAPInt(
elems.getSplatValue<APInt>(),
/*isSigned=*/!srcI.isUnsignedInteger() && !srcI.isSignlessInteger(),
APFloat::rmNearestTiesToEven);
f.convert(dstSem, APFloat::rmNearestTiesToEven, &losesInfo);
if (st == APFloat::opOK || st == APFloat::opInexact)
return DenseElementsAttr::get(outShaped, f);
return FloatAttr::get(dstF, f);
return {};
}

// floatX -> floatY
APFloat f = elems.getSplatValue<APFloat>();
bool losesInfo = false;
APFloat::opStatus st =
f.convert(dstSem, APFloat::rmNearestTiesToEven, &losesInfo);
if (st == APFloat::opOK || st == APFloat::opInexact)
return DenseElementsAttr::get(outShaped, f);
return {};
};

if (elems.isSplat()) {
Attribute singleElem = elems.getSplatValue<Attribute>();
if (auto converted = convertElement(singleElem)) {
return DenseElementsAttr::get(outShaped, *converted);
}
return {};
}

SmallVector<Attribute> converted;
converted.reserve(elems.getNumElements());

if (isa<mlir::FloatType>(srcEltTy)) {
for (const APFloat &v : elems.getValues<APFloat>()) {
if (auto convertedNum = convertElement(FloatAttr::get(srcEltTy, v))) {
converted.push_back(*convertedNum);
} else {
return {};
}
}
return DenseElementsAttr::get(outShaped, converted);
}

if (isa<IntegerType>(srcEltTy)) {
for (const APInt &v : elems.getValues<APInt>()) {
if (auto convertedNum = convertElement(IntegerAttr::get(srcEltTy, v))) {
converted.push_back(*convertedNum);
} else {
return {};
}
}
return DenseElementsAttr::get(outShaped, converted);
}

return {};
Expand Down
8 changes: 1 addition & 7 deletions test/Dialect/Torch/scalarize-shapes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,7 @@ func.func @cast_int_float(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor

// CHECK-LABEL: @cast_int_float_static
func.func @cast_int_float_static(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[3],f32> {
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[FLOAT2:.*]] = torch.constant.float 2.000000e+00
// CHECK: %[[FLOAT3:.*]] = torch.constant.float 3.000000e+00
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[FLOAT1:.*]], %[[FLOAT2:.*]], %[[FLOAT3:.*]] : (!torch.float, !torch.float, !torch.float) -> !torch.list<float>
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.list<float>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[3],f32>
// CHECK: %[[TENSOR:.*]] = torch.vtensor.literal(dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32>) : !torch.vtensor<[3],f32>
// CHECK: return %[[TENSOR]] : !torch.vtensor<[3],f32>
%int6 = torch.constant.int 6
%false = torch.constant.bool false
Expand Down