224 changes: 70 additions & 154 deletions lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -216,170 +216,88 @@ static LogicalResult createPoolingOp(
}

static Value createMaxUnpoolOp(Operation *op, int64_t poolingDimensionality,
ConversionPatternRewriter &rewriter,
const TypeConverter *typeConverter, Value self,
Value indices, ArrayRef<int64_t> inputSize,
ArrayRef<int64_t> inferredOutSize,
SmallVector<int64_t> &stride,
SmallVector<int64_t> &padding,
OpBuilder &b, Value self, Value indices,
ArrayRef<int64_t> stride,
RankedTensorType resType) {

Location loc = op->getLoc();
Type indexType = rewriter.getIndexType();

int64_t outRank = resType.getRank();
int64_t NC = outRank - poolingDimensionality;
Type indexType = b.getIndexType();
const int64_t outRank = resType.getRank();
const int64_t NC = outRank - poolingDimensionality;

auto selfType = cast<RankedTensorType>(self.getType());
auto indicesType = cast<RankedTensorType>(indices.getType());

SmallVector<Value> outSizePadded;
for (auto &&[i, size] : llvm::enumerate(resType.getShape())) {
if (int64_t(i) < NC) {
outSizePadded.emplace_back(rewriter.create<tensor::DimOp>(loc, self, i));
continue;
}
int64_t pad = padding[i - NC];

outSizePadded.emplace_back(
rewriter.create<arith::ConstantIndexOp>(loc, size + pad));
}

// If the input tensor size is not divisible by the stride
// (e.g. pooling_input_size=5, kernel_size=2, stride=2, output_size=2),
// pad the self and indices tensors to avoid out-of-bounds accesses.
SmallVector<int64_t> expectedInputShape =
llvm::to_vector(resType.getShape().drop_back(poolingDimensionality));
for (auto &&[str, pad, resSize] :
llvm::zip_equal(stride, padding, inferredOutSize))
expectedInputShape.emplace_back((resSize + str - 1) / str + pad * 2);

if (expectedInputShape != selfType.getShape()) {
// TODO: this is probably expensive, and it may be possible to solve by
// cleverly constructing affine maps for the next linalg.generic op,
// but I'm not smart enough to figure this out.

SmallVector<int64_t> low(outRank, 0);
SmallVector<int64_t> high(NC, 0);
for (auto &&[inpSize, outSize] : llvm::zip_equal(
inputSize,
ArrayRef(expectedInputShape).take_back(poolingDimensionality))) {
high.emplace_back(outSize - inpSize);
Type elementType = selfType.getElementType();

// Initialize output tensor with zeros
Value zero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
Value init = b.create<tensor::EmptyOp>(loc, resType.getShape(), elementType);
init = b.create<linalg::FillOp>(loc, zero, init)->getResult(0);

// Build affine expressions for the input/indices and output mappings.
// For 2D: input and indices map (d0,d1,d2,d3,d4,d5) -> (d0,d1,d4,d5), output
// -> (d0,d1,d2,d3).
SmallVector<AffineExpr> inputExprs, outputExprs;
inputExprs.reserve(outRank);
outputExprs.reserve(outRank);
const int64_t totalDims = outRank + poolingDimensionality;

for (int64_t i = 0; i < totalDims; ++i) {
AffineExpr dim = b.getAffineDimExpr(i);
if (i < outRank) {
outputExprs.push_back(dim);
}

// Pad the indices tensor with a value which cannot appear in real data
// (-1) so it will never match. In this case we can pad self with any
// value, as it will never affect the output.
Value zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(selfType.getElementType()));
Value invalidIdx = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(indicesType.getElementType(), -1));
self =
torch_to_linalg::getPaddedTensor(op, rewriter, self, low, high, zero);
indices = torch_to_linalg::getPaddedTensor(op, rewriter, indices, low, high,
invalidIdx);
}

Value init = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(outSizePadded), selfType.getElementType());

SmallVector<AffineExpr> inputExprs;
SmallVector<AffineExpr> outputExprs;
for (auto i : llvm::seq<int64_t>(0, outRank)) {
AffineExpr dim = rewriter.getAffineDimExpr(i);
if (i < NC) {
inputExprs.emplace_back(dim);
} else {
int64_t j = i - NC;
inputExprs.emplace_back(dim.floorDiv(stride[j]));
if (i < NC || i >= outRank) {
inputExprs.push_back(dim);
}
outputExprs.emplace_back(dim);
}

SmallVector<AffineMap> indexingMaps = AffineMap::inferFromExprList(
{inputExprs, inputExprs, outputExprs}, rewriter.getContext());

SmallVector<utils::IteratorType> iteratorTypes(outRank,
utils::IteratorType::parallel);

auto computeIndex = [&](OpBuilder &b, Location loc) -> Value {
// The next linalg.generic uses an identity mapping for the unpooled tensor;
// compute the linear index of the output element, which we will then compare
// with the values coming from the indices tensor.
Value ret;
for (auto i : llvm::seq<int64_t>(NC, outRank)) {
Value idx = b.create<linalg::IndexOp>(loc, i);
// If the pool input was padded, adjust the indices so they start at 0 in the
// non-padded area. Indices outside the non-padded area make no sense, but it
// doesn't matter, as we will cut off the padded area later with
// extract_slice.
int64_t pad = padding[i - NC];
if (pad != 0) {
Value padVal = b.create<arith::ConstantIndexOp>(loc, pad);
idx = b.create<arith::SubIOp>(loc, idx, padVal);
}

if (!ret) {
ret = idx;
{inputExprs, inputExprs, outputExprs}, b.getContext());

// Iterator types: parallel for output dims, reduction for pooling dims
SmallVector<utils::IteratorType> iteratorTypes;
iteratorTypes.reserve(totalDims);
iteratorTypes.append(outRank, utils::IteratorType::parallel);
iteratorTypes.append(poolingDimensionality, utils::IteratorType::reduction);

// Compute linear index from multi-dimensional coordinates (e.g., h*width + w)
auto computeLinearIndex = [&](OpBuilder &b, Location loc) -> Value {
Value linearIndex;
ArrayRef<int64_t> shape = resType.getShape();
for (int64_t i = NC; i < outRank; ++i) {
Value currentIdx = b.create<linalg::IndexOp>(loc, i);
if (!linearIndex) {
linearIndex = currentIdx;
} else {
Value size =
b.create<arith::ConstantIndexOp>(loc, resType.getShape()[i]);
ret = b.create<arith::MulIOp>(loc, ret, size);
ret = b.create<arith::AddIOp>(loc, ret, idx);
Value dimSize = b.create<arith::ConstantIndexOp>(loc, shape[i]);
linearIndex = b.create<arith::MulIOp>(loc, linearIndex, dimSize);
linearIndex = b.create<arith::AddIOp>(loc, linearIndex, currentIdx);
}
}
return ret;
return linearIndex;
};

auto builder = [&](OpBuilder &b, Location loc, ValueRange args) {
// Compute current output linear index and compare it with the value
// from indices arg.
Value input = args[0];
Value zero =
b.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(input.getType()));
Value index = b.create<arith::IndexCastOp>(loc, indexType, args[1]);
Value currentIndex = computeIndex(b, loc);
Value cmp = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, index,
currentIndex);
Value out = b.create<arith::SelectOp>(loc, cmp, input, zero);
b.create<linalg::YieldOp>(loc, out);
auto bodyBuilder = [&](OpBuilder &b, Location loc, ValueRange args) {
Value inputValue = args[0];
Value storedIndex = args[1];
Value currentOutput = args[2];

// Convert stored index to index type and compare with current position
Value indexAsIndex =
b.create<arith::IndexCastOp>(loc, indexType, storedIndex);
Value currentLinearIndex = computeLinearIndex(b, loc);
Value isMatch = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
indexAsIndex, currentLinearIndex);

// Select input value if indices match, otherwise keep current output
Value result =
b.create<arith::SelectOp>(loc, isMatch, inputValue, currentOutput);
b.create<linalg::YieldOp>(loc, result);
};

Value result =
rewriter
.create<linalg::GenericOp>(loc,
/*resultTensorTypes=*/init.getType(),
/*inputs=*/ValueRange({self, indices}),
/*outputs=*/init,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes, builder)
.getResult(0);

if (llvm::any_of(padding, [](int64_t v) { return v != 0; })) {
// The MaxPool input was padded; unpad it by taking a slice.
SmallVector<OpFoldResult> offsetVals(NC, rewriter.getI64IntegerAttr(0));
for (int64_t pad : padding)
offsetVals.emplace_back(rewriter.getI64IntegerAttr(pad));

SmallVector<OpFoldResult> sizeVals;
for (auto &&[i, dim] : llvm::enumerate(resType.getShape())) {
if (!ShapedType::isDynamic(dim)) {
sizeVals.emplace_back(rewriter.getI64IntegerAttr(dim));
continue;
}

sizeVals.emplace_back(rewriter.create<tensor::DimOp>(loc, self, i));
}
SmallVector<OpFoldResult> stridesVals(outRank,
rewriter.getI64IntegerAttr(1));
result = rewriter.create<tensor::ExtractSliceOp>(loc, result, offsetVals,
sizeVals, stridesVals);
}

if (result.getType() != resType)
result = rewriter.create<tensor::CastOp>(loc, resType, result);

return result;
return b
.create<linalg::GenericOp>(loc, init.getType(), ValueRange{self, indices},
init, indexingMaps, iteratorTypes, bodyBuilder)
.getResult(0);
}

namespace {
@@ -843,9 +761,8 @@ class ConvertAtenMaxUnpool3dOp final
}

int64_t poolingDimensionality = 3;
Value result = createMaxUnpoolOp(
op, poolingDimensionality, rewriter, typeConverter, self, indices,
spatialInputShape, inferredOutSize, stride, padding, resType);
Value result = createMaxUnpoolOp(op, poolingDimensionality, rewriter, self,
indices, stride, resType);

rewriter.replaceOp(op, result);
return success();
@@ -905,9 +822,8 @@ class ConvertAtenMaxUnpool2dOp final
SmallVector<int64_t> stride(poolingDimensionality, poolingDimensionality);
SmallVector<int64_t> padding(poolingDimensionality, 0);

Value result = createMaxUnpoolOp(op, poolingDimensionality, rewriter,
typeConverter, self, indices, inputSize,
inferredOutSize, stride, padding, resType);
Value result = createMaxUnpoolOp(op, poolingDimensionality, rewriter, self,
indices, stride, resType);

rewriter.replaceOp(op, result);
return success();
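For orientation (an editorial sketch, not part of the patch): the new lowering drops the pad / floordiv-affine-map / extract_slice pipeline and instead zero-fills the output, then runs a linalg.generic whose reduction dimensions scan every pooled input position and copy a value only where its stored index equals the linearized output coordinate. Below is a minimal C++ sketch of that select-by-index scheme for a single (N, C) plane; all names are illustrative and do not come from the patch.

#include <cstdint>
#include <vector>

// Unpool one (N, C) plane: the output is zero-initialized (the linalg.fill),
// and each output position keeps the pooled input value whose stored flat
// index equals h * outW + w, mirroring the linalg.generic body above.
std::vector<float> maxUnpool2dPlane(const std::vector<float> &input,
                                    const std::vector<int64_t> &indices,
                                    int64_t outH, int64_t outW) {
  std::vector<float> output(outH * outW, 0.0f);
  for (int64_t oh = 0; oh < outH; ++oh) {
    for (int64_t ow = 0; ow < outW; ++ow) {
      // Same linearization as computeLinearIndex in the lowering.
      int64_t linear = oh * outW + ow;
      // The reduction dims: scan every pooled input element, select on match.
      for (size_t i = 0; i < input.size(); ++i)
        if (indices[i] == linear)
          output[linear] = input[i];
    }
  }
  return output;
}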
47 changes: 28 additions & 19 deletions test/Conversion/TorchToLinalg/pooling.mlir
@@ -48,29 +48,38 @@ func.func @forward_max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vt

// -----

// CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 2, d3 floordiv 2)>
// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func @forward_max_unpool2d
// CHECK: #[[$INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>
// CHECK: #[[$OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func.func @forward_max_unpool2d(
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[2,2,2,4],f32>,
// CHECK-SAME: %[[INDICES:.*]]: !torch.vtensor<[2,2,2,4],si64>) -> !torch.vtensor<[2,2,4,8],f32> {
// CHECK: %[[INDICES_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INDICES]] : !torch.vtensor<[2,2,2,4],si64> -> tensor<2x2x2x4xi64>
// CHECK: %[[INPUT_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT]] : !torch.vtensor<[2,2,2,4],f32> -> tensor<2x2x2x4xf32>
// CHECK: %[[C8:.*]] = torch.constant.int 8
// CHECK: %[[C4:.*]] = torch.constant.int 4
// CHECK: %[[OUTPUT_SIZE:.*]] = torch.prim.ListConstruct %[[C4]], %[[C8]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[EMPTY_OUTPUT:.*]] = tensor.empty() : tensor<2x2x4x8xf32>
// CHECK: %[[FILLED_OUTPUT:.*]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[EMPTY_OUTPUT]] : tensor<2x2x4x8xf32>) -> tensor<2x2x4x8xf32>
// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$INPUT_MAP]], #[[$INPUT_MAP]], #[[$OUTPUT_MAP]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[INPUT_TENSOR]], %[[INDICES_TENSOR]] : tensor<2x2x2x4xf32>, tensor<2x2x2x4xi64>) outs(%[[FILLED_OUTPUT]] : tensor<2x2x4x8xf32>) {
// CHECK: ^bb0(%[[INPUT_VAL:.*]]: f32, %[[INDEX_VAL:.*]]: i64, %[[OUTPUT_VAL:.*]]: f32):
// CHECK: %[[INDEX_CAST:.*]] = arith.index_cast %[[INDEX_VAL]] : i64 to index
// CHECK: %[[ROW_IDX:.*]] = linalg.index 2 : index
// CHECK: %[[COL_IDX:.*]] = linalg.index 3 : index
// CHECK: %[[WIDTH:.*]] = arith.constant 8 : index
// CHECK: %[[LINEAR_IDX:.*]] = arith.muli %[[ROW_IDX]], %[[WIDTH]] : index
// CHECK: %[[FLAT_IDX:.*]] = arith.addi %[[LINEAR_IDX]], %[[COL_IDX]] : index
// CHECK: %[[IS_MATCH:.*]] = arith.cmpi eq, %[[INDEX_CAST]], %[[FLAT_IDX]] : index
// CHECK: %[[SELECTED:.*]] = arith.select %[[IS_MATCH]], %[[INPUT_VAL]], %[[OUTPUT_VAL]] : f32
// CHECK: linalg.yield %[[SELECTED]] : f32
// CHECK: } -> tensor<2x2x4x8xf32>
// CHECK: %[[OUTPUT_TENSOR:.*]] = torch_c.from_builtin_tensor %[[RESULT]] : tensor<2x2x4x8xf32> -> !torch.vtensor<[2,2,4,8],f32>
// CHECK: return %[[OUTPUT_TENSOR]] : !torch.vtensor<[2,2,4,8],f32>
// CHECK: }
func.func @forward_max_unpool2d(%arg0: !torch.vtensor<[2,2,2,4],f32>, %arg1: !torch.vtensor<[2,2,2,4],si64>) -> !torch.vtensor<[2,2,4,8],f32> {
%int8 = torch.constant.int 8
%int4 = torch.constant.int 4
%0 = torch.prim.ListConstruct %int4, %int8 : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: = linalg.generic
// CHECK-SAME: indexing_maps = [#map, #map, #map1]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
// CHECK: ins(
// CHECK: outs(
// CHECK: ^bb0(
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[CAST:.*]] = arith.index_cast %{{.*}} : i64 to index
// CHECK: %[[IDX2:.*]] = linalg.index 2 : index
// CHECK: %[[IDX3:.*]] = linalg.index 3 : index
// CHECK: %[[C8_2:.*]] = arith.constant 8 : index
// CHECK: %[[MUL:.*]] = arith.muli %[[IDX2]], %[[C8_2]] : index
// CHECK: %[[ADD:.*]] = arith.addi %[[MUL]], %[[IDX3]] : index
// CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[CAST]], %[[ADD]] : index
// CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %{{.*}}, %[[CST]] : f32
// CHECK: linalg.yield %[[SEL]] : f32
%1 = torch.aten.max_unpool2d %arg0, %arg1, %0 : !torch.vtensor<[2,2,2,4],f32>, !torch.vtensor<[2,2,2,4],si64>, !torch.list<int> -> !torch.vtensor<[2,2,4,8],f32>
return %1 : !torch.vtensor<[2,2,4,8],f32>
}
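As a concrete check of the linearization exercised by the CHECK lines above: the unpooled output plane is 4x8, so, for example, a stored index of 11 selects output position (h, w) = (1, 3), because 1 * 8 + 3 = 11; every other position in that plane keeps the 0.0 fill value.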