Skip to content

Commit 889a836

Browse files
aldesilvzjgarvey
andauthored
OnnxToTorch bicubic interpolation (#3802)
(nod-ai/SHARK-TestSuite#391) Repro (using SHARK TestSuite): 1. `python run.py --torchtolinalg -m cl-onnx-iree -t cubic_test` --------- Co-authored-by: zjgarvey <[email protected]>
1 parent 17c1985 commit 889a836

File tree

3 files changed

+292
-31
lines changed

3 files changed

+292
-31
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2922,7 +2922,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
29222922
llvm::SmallVector<Value> operands;
29232923
std::string mode, nearest_mode, coordTfMode;
29242924
int64_t antialias, exclude_outside;
2925-
float extrapolation_value;
2925+
float extrapolation_value, cubic_coeff_a;
29262926
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
29272927

29282928
if (auto attr = binder.op->getAttr("torch.onnx.axes")) {
@@ -2947,7 +2947,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
29472947
binder.f32FloatAttr(extrapolation_value, "extrapolation_value",
29482948
0.0) ||
29492949
binder.customOpNameStringAttr(nearest_mode, "nearest_mode",
2950-
"round_prefer_floor"))
2950+
"round_prefer_floor") ||
2951+
binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
29512952
return failure();
29522953
if (antialias != 0) {
29532954
return rewriter.notifyMatchFailure(
@@ -2976,6 +2977,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
29762977
"except asymmetric and half_pixel");
29772978
}
29782979

2980+
if (mode == "cubic" && cubic_coeff_a != -0.75) {
2981+
return rewriter.notifyMatchFailure(
2982+
binder.op, "unimplemented: cubic coeff must be -0.75");
2983+
}
2984+
29792985
unsigned rank = dyn_cast<Torch::ValueTensorType>(operands[0].getType())
29802986
.getSizes()
29812987
.size();
@@ -2991,8 +2997,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
29912997
Value alignCorners =
29922998
coordTfMode == "align_corners" ? cstTrue : cstFalse;
29932999
if (mode == "cubic") {
2994-
return rewriter.notifyMatchFailure(binder.op,
2995-
"unimplemented: bicubic mode");
3000+
std::string modeStr = "cubic";
3001+
if (coordTfMode != "half_pixel")
3002+
modeStr = modeStr + "_" + coordTfMode;
3003+
modeStrValue =
3004+
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
29963005
}
29973006
// supported modes:
29983007
// bilinear (half_pixel), bilinear with align_corners,

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 230 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2683,7 +2683,7 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
26832683
};
26842684
} // namespace
26852685

2686-
static Value NearestInterpolate(OpBuilder &b, Location loc,
2686+
static Value nearestInterpolate(OpBuilder &b, Location loc,
26872687
SmallVector<Value> outputSizes, Value input,
26882688
SmallVector<Value> inputSizes,
26892689
SmallVector<Value> scaleValues,
@@ -2771,12 +2771,12 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
27712771
return retVal;
27722772
}
27732773

2774-
static Value BilinearInterpolate(OpBuilder &b,
2775-
Aten__InterpolateSizeListScaleListOp op,
2776-
Location loc, SmallVector<Value> outputSizes,
2777-
Value input, SmallVector<Value> inputSizes,
2778-
SmallVector<Value> scaleValues,
2779-
std::string coordStr) {
2774+
static SmallVector<Value> coordinateTransform(
2775+
OpBuilder &b, Aten__InterpolateSizeListScaleListOp op, Location loc,
2776+
SmallVector<Value> outputSizes, Value input, SmallVector<Value> inputSizes,
2777+
SmallVector<Value> scaleValues, std::string coordStr, bool alignCornersBool,
2778+
SmallVector<Value> indices, bool clip) {
2779+
27802780
unsigned dimOffset = 2;
27812781
auto inputType = cast<RankedTensorType>(input.getType());
27822782
auto inputRank = inputType.getRank();
@@ -2785,15 +2785,7 @@ static Value BilinearInterpolate(OpBuilder &b,
27852785
Value cstHalf = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
27862786
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
27872787

2788-
bool alignCornersBool;
2789-
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));
2790-
2791-
SmallVector<Value> indices;
2792-
for (unsigned i = 0; i < inputRank; i++) {
2793-
indices.push_back(b.create<linalg::IndexOp>(loc, i));
2794-
}
2795-
2796-
SmallVector<Value> proj, projEps, high, low, highFP, lowFP;
2788+
SmallVector<Value> proj;
27972789
for (unsigned i = 0; i < inputRank - dimOffset; i++) {
27982790
// length_original
27992791
Value inputFP =
@@ -2856,13 +2848,50 @@ static Value BilinearInterpolate(OpBuilder &b,
28562848
outputSizeFP, cstOneFloat);
28572849
preClip = b.create<arith::SelectOp>(loc, cmp, zero, preClip);
28582850
}
2859-
// preClip is the fp position inside the input image to extract from.
2860-
// clip to [0,inf)
2861-
Value max = b.create<arith::MaximumFOp>(loc, preClip, zero);
2851+
if (clip) {
2852+
// preClip is the fp position inside the input image to extract from.
2853+
// clip to [0,inf)
2854+
Value max = b.create<arith::MaximumFOp>(loc, preClip, zero);
2855+
Value inputSubOne = b.create<arith::SubFOp>(loc, inputFP, cstOneFloat);
2856+
// clip to [0,length_original - 1].
2857+
// proj is properly within the input image.
2858+
proj.push_back(b.create<arith::MinimumFOp>(loc, max, inputSubOne));
2859+
} else {
2860+
proj.push_back(preClip);
2861+
}
2862+
}
2863+
return proj;
2864+
}
2865+
2866+
static Value bilinearInterpolate(OpBuilder &b,
2867+
Aten__InterpolateSizeListScaleListOp op,
2868+
Location loc, SmallVector<Value> outputSizes,
2869+
Value input, SmallVector<Value> inputSizes,
2870+
SmallVector<Value> scaleValues,
2871+
std::string coordStr) {
2872+
unsigned dimOffset = 2;
2873+
auto inputType = cast<RankedTensorType>(input.getType());
2874+
auto inputRank = inputType.getRank();
2875+
2876+
Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
2877+
2878+
bool alignCornersBool;
2879+
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));
2880+
2881+
SmallVector<Value> indices;
2882+
for (unsigned i = 0; i < inputRank; i++) {
2883+
indices.push_back(b.create<linalg::IndexOp>(loc, i));
2884+
}
2885+
2886+
SmallVector<Value> proj, high, low, highFP, lowFP;
2887+
proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes,
2888+
scaleValues, coordStr, alignCornersBool, indices,
2889+
true);
2890+
for (unsigned i = 0; i < inputRank - dimOffset; i++) {
2891+
// length_original
2892+
Value inputFP =
2893+
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[i]);
28622894
Value inputSubOne = b.create<arith::SubFOp>(loc, inputFP, cstOneFloat);
2863-
// clip to [0,length_original - 1].
2864-
// proj is properly within the input image.
2865-
proj.push_back(b.create<arith::MinimumFOp>(loc, max, inputSubOne));
28662895

28672896
// for bilinear interpolation, we look for the nearest indices below and
28682897
// above proj
@@ -2926,6 +2955,176 @@ static Value BilinearInterpolate(OpBuilder &b,
29262955
return b.create<arith::AddFOp>(loc, left, right);
29272956
}
29282957

2958+
static Value bicubicInterpolate(OpBuilder &b,
2959+
Aten__InterpolateSizeListScaleListOp op,
2960+
Location loc, SmallVector<Value> outputSizes,
2961+
Value input, SmallVector<Value> inputSizes,
2962+
SmallVector<Value> scaleValues,
2963+
std::string coordStr) {
2964+
unsigned dimOffset = 2;
2965+
auto inputType = cast<RankedTensorType>(input.getType());
2966+
auto inputRank = inputType.getRank();
2967+
2968+
Value inputFPH =
2969+
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[0]);
2970+
Value inputFPW =
2971+
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[1]);
2972+
2973+
Value a = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(-0.75));
2974+
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
2975+
Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
2976+
Value cstTwoFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(2.0));
2977+
Value cstThreeFloat =
2978+
b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(3.0));
2979+
Value cstFourFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(4.0));
2980+
Value cstFiveFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(5.0));
2981+
Value cstEightFloat =
2982+
b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(8.0));
2983+
2984+
// (a+2)|x|^3 - (a+3)|x|^2 + 1 for xDistance (|x| <= 1)
2985+
auto WeightLessThanEqualOne = [&](Value xDistance) -> Value {
2986+
Value xDistanceSquared = b.create<arith::MulFOp>(loc, xDistance, xDistance);
2987+
Value xDistanceCubed =
2988+
b.create<arith::MulFOp>(loc, xDistanceSquared, xDistance);
2989+
2990+
Value lessEqualOne = b.create<arith::AddFOp>(loc, a, cstTwoFloat);
2991+
lessEqualOne = b.create<arith::MulFOp>(loc, xDistanceCubed, lessEqualOne);
2992+
Value aPlusThree = b.create<arith::AddFOp>(loc, a, cstThreeFloat);
2993+
aPlusThree = b.create<arith::MulFOp>(loc, xDistanceSquared, aPlusThree);
2994+
lessEqualOne = b.create<arith::SubFOp>(loc, lessEqualOne, aPlusThree);
2995+
lessEqualOne = b.create<arith::AddFOp>(loc, lessEqualOne, cstOneFloat);
2996+
2997+
return lessEqualOne;
2998+
};
2999+
3000+
// a|x|^3 - 5a|x|^2 + 8a|x| - 4a for xDistance (1 < |x| < 2)
3001+
auto WeightLessThanTwo = [&](Value xDistance) -> Value {
3002+
Value xDistanceSquared = b.create<arith::MulFOp>(loc, xDistance, xDistance);
3003+
Value xDistanceCubed =
3004+
b.create<arith::MulFOp>(loc, xDistanceSquared, xDistance);
3005+
// a|x|^3
3006+
Value lessThanTwo = b.create<arith::MulFOp>(loc, xDistanceCubed, a);
3007+
3008+
Value fiveA = b.create<arith::MulFOp>(loc, xDistanceSquared, a);
3009+
fiveA = b.create<arith::MulFOp>(loc, fiveA, cstFiveFloat);
3010+
// a|x|^3 - 5a|x|^2
3011+
lessThanTwo = b.create<arith::SubFOp>(loc, lessThanTwo, fiveA);
3012+
3013+
Value eightA = b.create<arith::MulFOp>(loc, a, xDistance);
3014+
eightA = b.create<arith::MulFOp>(loc, eightA, cstEightFloat);
3015+
// a|x|^3 - 5a|x|^2 + 8a|x|
3016+
lessThanTwo = b.create<arith::AddFOp>(loc, eightA, lessThanTwo);
3017+
3018+
Value fourA = b.create<arith::MulFOp>(loc, a, cstFourFloat);
3019+
// a|x|^3 - 5a|x|^2 + 8a|x| - 4a
3020+
lessThanTwo = b.create<arith::SubFOp>(loc, lessThanTwo, fourA);
3021+
return lessThanTwo;
3022+
};
3023+
3024+
bool alignCornersBool;
3025+
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));
3026+
3027+
SmallVector<Value> indices;
3028+
for (unsigned i = 0; i < inputRank; i++) {
3029+
indices.push_back(b.create<linalg::IndexOp>(loc, i));
3030+
}
3031+
3032+
SmallVector<Value> proj;
3033+
3034+
proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes,
3035+
scaleValues, coordStr, alignCornersBool, indices,
3036+
false);
3037+
3038+
// get the nearest neighbors of proj
3039+
Value x1 = b.create<math::CeilOp>(loc, proj[1]);
3040+
Value x_1 = b.create<arith::SubFOp>(loc, x1, cstOneFloat);
3041+
Value x_2 = b.create<arith::SubFOp>(loc, x_1, cstOneFloat);
3042+
Value x2 = b.create<arith::AddFOp>(loc, x1, cstOneFloat);
3043+
3044+
Value y1 = b.create<math::CeilOp>(loc, proj[0]);
3045+
Value y_1 = b.create<arith::SubFOp>(loc, y1, cstOneFloat);
3046+
Value y_2 = b.create<arith::SubFOp>(loc, y_1, cstOneFloat);
3047+
Value y2 = b.create<arith::AddFOp>(loc, y1, cstOneFloat);
3048+
3049+
// calculate the distance of nearest neighbors x and y to proj
3050+
Value y2Distance = b.create<arith::SubFOp>(loc, proj[0], y2);
3051+
y2Distance = b.create<math::AbsFOp>(loc, y2Distance);
3052+
Value y1Distance = b.create<arith::SubFOp>(loc, proj[0], y1);
3053+
y1Distance = b.create<math::AbsFOp>(loc, y1Distance);
3054+
Value y_1Distance = b.create<arith::SubFOp>(loc, proj[0], y_1);
3055+
y_1Distance = b.create<math::AbsFOp>(loc, y_1Distance);
3056+
Value y_2Distance = b.create<arith::SubFOp>(loc, proj[0], y_2);
3057+
y_2Distance = b.create<math::AbsFOp>(loc, y_2Distance);
3058+
3059+
Value x2Distance = b.create<arith::SubFOp>(loc, proj[1], x2);
3060+
x2Distance = b.create<math::AbsFOp>(loc, x2Distance);
3061+
Value x1Distance = b.create<arith::SubFOp>(loc, proj[1], x1);
3062+
x1Distance = b.create<math::AbsFOp>(loc, x1Distance);
3063+
Value x_1Distance = b.create<arith::SubFOp>(loc, proj[1], x_1);
3064+
x_1Distance = b.create<math::AbsFOp>(loc, x_1Distance);
3065+
Value x_2Distance = b.create<arith::SubFOp>(loc, proj[1], x_2);
3066+
x_2Distance = b.create<math::AbsFOp>(loc, x_2Distance);
3067+
3068+
SmallVector<Value> y{y_2, y_1, y1, y2};
3069+
SmallVector<Value> x{x_2, x_1, x1, x2};
3070+
3071+
SmallVector<Value> wys{
3072+
WeightLessThanTwo(y_2Distance), WeightLessThanEqualOne(y_1Distance),
3073+
WeightLessThanEqualOne(y1Distance), WeightLessThanTwo(y2Distance)};
3074+
SmallVector<Value> wxs{
3075+
WeightLessThanTwo(x_2Distance), WeightLessThanEqualOne(x_1Distance),
3076+
WeightLessThanEqualOne(x1Distance), WeightLessThanTwo(x2Distance)};
3077+
3078+
// clip the nearest neighbors points to inside the original image
3079+
for (int k = 0; k < 4; k++) {
3080+
Value yClipped = b.create<arith::MaximumFOp>(loc, y[k], zero);
3081+
Value inputHSubOne = b.create<arith::SubFOp>(loc, inputFPH, cstOneFloat);
3082+
yClipped = b.create<arith::MinimumFOp>(loc, yClipped, inputHSubOne);
3083+
Value yInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), yClipped);
3084+
y[k] = b.create<arith::IndexCastOp>(loc, b.getIndexType(), yInt);
3085+
3086+
Value xClipped = b.create<arith::MaximumFOp>(loc, x[k], zero);
3087+
Value inputWSubOne = b.create<arith::SubFOp>(loc, inputFPW, cstOneFloat);
3088+
xClipped = b.create<arith::MinimumFOp>(loc, xClipped, inputWSubOne);
3089+
Value xInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), xClipped);
3090+
x[k] = b.create<arith::IndexCastOp>(loc, b.getIndexType(), xInt);
3091+
}
3092+
// 1. Compute x_original and y_original (proj)
3093+
// 2. Compute nearest x and y neighbors
3094+
// 3. Compute Wx Wy
3095+
// 4. Extract inputs at nearest neighbors (inputExtracts)
3096+
// 5. Compute weighted sum (yield this)
3097+
3098+
// 4 nearest x neighbors : [x_2, x_1, x1, x2] of x_original
3099+
// 4 nearest y neighbors : [y_2, y_1, y1, y2] of y_original
3100+
// Sum_x is over 4 nearest x neighbors (similar for Sum_y)
3101+
// f(x_original, y_original) = Sum_y Sum_x W(x_original - x)*input[x,y]
3102+
// * W(y_original - y)
3103+
Value fxy = zero;
3104+
3105+
for (int j = 0; j < 4; j++) {
3106+
Value wy = wys[j];
3107+
Value xInterpy = zero;
3108+
3109+
indices[dimOffset] = y[j];
3110+
3111+
for (int i = 0; i < 4; i++) {
3112+
Value wx = wxs[i];
3113+
3114+
indices[dimOffset + 1] = x[i];
3115+
3116+
Value p = b.create<tensor::ExtractOp>(loc, input, indices);
3117+
3118+
Value wxp = b.create<arith::MulFOp>(loc, wx, p);
3119+
xInterpy = b.create<arith::AddFOp>(loc, xInterpy, wxp);
3120+
}
3121+
Value wyXInterpy = b.create<arith::MulFOp>(loc, wy, xInterpy);
3122+
fxy = b.create<arith::AddFOp>(loc, fxy, wyXInterpy);
3123+
}
3124+
3125+
return fxy;
3126+
}
3127+
29293128
namespace {
29303129
class ConvertInterpolateOp
29313130
: public OpConversionPattern<Aten__InterpolateSizeListScaleListOp> {
@@ -2941,7 +3140,8 @@ class ConvertInterpolateOp
29413140
// coordinate_transformation_mode="asymmetric" will lower to an interpolate
29423141
// op with the non-standard mode="bilinear_asymmetric".
29433142
matchPattern(op.getMode(), m_TorchConstantStr(mode));
2944-
if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest") {
3143+
if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest" &&
3144+
mode.substr(0, 5) != "cubic") {
29453145
return failure();
29463146
}
29473147

@@ -3023,13 +3223,18 @@ class ConvertInterpolateOp
30233223
(mode.find(",") == std::string::npos)
30243224
? ""
30253225
: mode.substr(mode.find(",") + 1);
3026-
retVal = NearestInterpolate(
3226+
retVal = nearestInterpolate(
30273227
b, loc, outputSizeIntValues, input, inputSizes,
30283228
ScaleFactorFloatValues, coordTfMode, nearestMode);
30293229
} else if (mode.substr(0, 8) == "bilinear") {
3030-
retVal = BilinearInterpolate(
3230+
retVal = bilinearInterpolate(
30313231
b, op, loc, outputSizeIntValues, input, inputSizes,
30323232
ScaleFactorFloatValues, mode.substr(8));
3233+
} else if (mode.substr(0, 5) == "cubic") {
3234+
3235+
retVal = bicubicInterpolate(
3236+
b, op, loc, outputSizeIntValues, input, inputSizes,
3237+
ScaleFactorFloatValues, mode.substr(5));
30333238
}
30343239
b.create<linalg::YieldOp>(loc, retVal);
30353240
})

0 commit comments

Comments
 (0)