diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index b58c6f211b03..2126a2ea0c97 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -911,9 +911,15 @@ OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) { if (inputType == resType && inputType.hasDtype()) return getOperand(0); - // Fold conversion of splat values. - auto elems = dyn_cast_or_null(adaptor.getSelf()); - if (!elems || !elems.isSplat()) + DenseElementsAttr elems; + if (auto valueTensorLiteralOp = + getOperand(0).getDefiningOp()) { + elems = + dyn_cast_or_null(valueTensorLiteralOp.getValue()); + } else { + elems = dyn_cast_or_null(adaptor.getSelf()); + } + if (!elems || (!elems.isSplat() && elems.size() > 10)) return {}; auto outVTy = dyn_cast(getType()); @@ -927,64 +933,101 @@ OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) { Type srcEltTy = inputType.getDtype(); Type dstEltTy = outVTy.getDtype(); - // Handle integer destination. - if (auto dstI = dyn_cast(dstEltTy)) { - // any -> bool(i1). - if (dstI.isSignlessInteger(1)) { - bool truthy = false; - if (isa(srcEltTy)) { - const APFloat &floatVal = elems.getSplatValue(); - truthy = !floatVal.isZero(); - } else { - const APInt &intVal = elems.getSplatValue(); - truthy = !intVal.isZero(); + auto convertElement = [&](Attribute srcAttr) -> std::optional { + // Handle integer destination. + if (auto dstI = dyn_cast(dstEltTy)) { + // any -> bool(i1). + if (dstI.isSignlessInteger(1)) { + bool truthy = false; + if (isa(srcEltTy)) { + const APFloat &floatVal = cast(srcAttr).getValue(); + truthy = !floatVal.isZero(); + } else { + const APInt &intVal = cast(srcAttr).getValue(); + truthy = !intVal.isZero(); + } + return IntegerAttr::get(dstEltTy, APInt(/*numBits=*/1, truthy)); } - return DenseElementsAttr::get(outShaped, APInt(/*numBits=*/1, truthy)); + // float -> intN + if (auto srcF = dyn_cast(srcEltTy)) { + APSInt result(dstI.getWidth(), /*isUnsigned=*/dstI.isUnsignedInteger()); + bool isExact = false; + APFloat f = cast(srcAttr).getValue(); + APFloat::opStatus st = + f.convertToInteger(result, APFloat::rmTowardZero, &isExact); + if (st == APFloat::opOK || st == APFloat::opInexact) + return IntegerAttr::get(dstEltTy, APInt(result)); + return {}; // NaN/Inf/out-of-range: preserve runtime semantics. + } + // intM -> intN + const APInt v = cast(srcAttr).getValue(); + auto isUnsigned = cast(srcEltTy).isUnsignedInteger(); + auto isSignless = cast(srcEltTy).isSignlessInteger(); + APInt casted = isUnsigned || isSignless ? v.zextOrTrunc(dstI.getWidth()) + : v.sextOrTrunc(dstI.getWidth()); + return IntegerAttr::get(dstEltTy, casted); } - // float -> intN - if (auto srcF = dyn_cast(srcEltTy)) { - APSInt result(dstI.getWidth(), /*isUnsigned=*/dstI.isUnsignedInteger()); - bool isExact = false; - APFloat f = elems.getSplatValue(); + + // Handle float destination. + if (auto dstF = dyn_cast(dstEltTy)) { + const llvm::fltSemantics &dstSem = dstF.getFloatSemantics(); + + // int -> float + if (auto srcI = dyn_cast(srcEltTy)) { + APFloat f(dstSem); + APFloat::opStatus st = f.convertFromAPInt( + cast(srcAttr).getValue(), + /*isSigned=*/!srcI.isUnsignedInteger() && !srcI.isSignlessInteger(), + APFloat::rmNearestTiesToEven); + if (st == APFloat::opOK || st == APFloat::opInexact) + return FloatAttr::get(dstF, f); + return {}; + } + + // floatX -> floatY + APFloat f = cast(srcAttr).getValue(); + bool losesInfo = false; APFloat::opStatus st = - f.convertToInteger(result, APFloat::rmTowardZero, &isExact); - if (st == APFloat::opOK || st == APFloat::opInexact) - return DenseElementsAttr::get(outShaped, APInt(result)); - return {}; // NaN/Inf/out-of-range: preserve runtime semantics. - } - // intM -> intN - const APInt &v = elems.getSplatValue(); - auto isUnsigned = cast(srcEltTy).isUnsignedInteger(); - auto isSignless = cast(srcEltTy).isSignlessInteger(); - APInt casted = isUnsigned || isSignless ? v.zextOrTrunc(dstI.getWidth()) - : v.sextOrTrunc(dstI.getWidth()); - return DenseElementsAttr::get(outShaped, casted); - } - - // Handle float destination. - if (auto dstF = dyn_cast(dstEltTy)) { - const llvm::fltSemantics &dstSem = dstF.getFloatSemantics(); - - // int -> float - if (auto srcI = dyn_cast(srcEltTy)) { - APFloat f(dstSem); - APFloat::opStatus st = f.convertFromAPInt( - elems.getSplatValue(), - /*isSigned=*/!srcI.isUnsignedInteger() && !srcI.isSignlessInteger(), - APFloat::rmNearestTiesToEven); + f.convert(dstSem, APFloat::rmNearestTiesToEven, &losesInfo); if (st == APFloat::opOK || st == APFloat::opInexact) - return DenseElementsAttr::get(outShaped, f); + return FloatAttr::get(dstF, f); return {}; } - // floatX -> floatY - APFloat f = elems.getSplatValue(); - bool losesInfo = false; - APFloat::opStatus st = - f.convert(dstSem, APFloat::rmNearestTiesToEven, &losesInfo); - if (st == APFloat::opOK || st == APFloat::opInexact) - return DenseElementsAttr::get(outShaped, f); return {}; + }; + + if (elems.isSplat()) { + Attribute singleElem = elems.getSplatValue(); + if (auto converted = convertElement(singleElem)) { + return DenseElementsAttr::get(outShaped, *converted); + } + return {}; + } + + SmallVector converted; + converted.reserve(elems.getNumElements()); + + if (isa(srcEltTy)) { + for (const APFloat &v : elems.getValues()) { + if (auto convertedNum = convertElement(FloatAttr::get(srcEltTy, v))) { + converted.push_back(*convertedNum); + } else { + return {}; + } + } + return DenseElementsAttr::get(outShaped, converted); + } + + if (isa(srcEltTy)) { + for (const APInt &v : elems.getValues()) { + if (auto convertedNum = convertElement(IntegerAttr::get(srcEltTy, v))) { + converted.push_back(*convertedNum); + } else { + return {}; + } + } + return DenseElementsAttr::get(outShaped, converted); } return {}; diff --git a/test/Dialect/Torch/scalarize-shapes.mlir b/test/Dialect/Torch/scalarize-shapes.mlir index 00975a2405be..1ebb26bdf307 100644 --- a/test/Dialect/Torch/scalarize-shapes.mlir +++ b/test/Dialect/Torch/scalarize-shapes.mlir @@ -87,13 +87,7 @@ func.func @cast_int_float(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor // CHECK-LABEL: @cast_int_float_static func.func @cast_int_float_static(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[3],f32> { - // CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 - // CHECK: %[[FLOAT2:.*]] = torch.constant.float 2.000000e+00 - // CHECK: %[[FLOAT3:.*]] = torch.constant.float 3.000000e+00 - // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[FLOAT1:.*]], %[[FLOAT2:.*]], %[[FLOAT3:.*]] : (!torch.float, !torch.float, !torch.float) -> !torch.list - // CHECK: %[[NONE:.*]] = torch.constant.none - // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[3],f32> + // CHECK: %[[TENSOR:.*]] = torch.vtensor.literal(dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf32>) : !torch.vtensor<[3],f32> // CHECK: return %[[TENSOR]] : !torch.vtensor<[3],f32> %int6 = torch.constant.int 6 %false = torch.constant.bool false