diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index 4e03ff578b9f..50ebd80a2e4c 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -216,170 +216,88 @@ static LogicalResult createPoolingOp(
 }
 
 static Value createMaxUnpoolOp(Operation *op, int64_t poolingDimensionality,
-                               ConversionPatternRewriter &rewriter,
-                               const TypeConverter *typeConverter, Value self,
-                               Value indices, ArrayRef<int64_t> inputSize,
-                               ArrayRef<int64_t> inferredOutSize,
-                               SmallVector<int64_t> &stride,
-                               SmallVector<int64_t> &padding,
+                               OpBuilder &b, Value self, Value indices,
+                               ArrayRef<int64_t> stride,
                                RankedTensorType resType) {
   Location loc = op->getLoc();
-  Type indexType = rewriter.getIndexType();
-
-  int64_t outRank = resType.getRank();
-  int64_t NC = outRank - poolingDimensionality;
+  Type indexType = b.getIndexType();
+  const int64_t outRank = resType.getRank();
+  const int64_t NC = outRank - poolingDimensionality;
 
   auto selfType = cast<RankedTensorType>(self.getType());
-  auto indicesType = cast<RankedTensorType>(indices.getType());
-
-  SmallVector<Value> outSizePadded;
-  for (auto &&[i, size] : llvm::enumerate(resType.getShape())) {
-    if (int64_t(i) < NC) {
-      outSizePadded.emplace_back(rewriter.create<tensor::DimOp>(loc, self, i));
-      continue;
-    }
-    int64_t pad = padding[i - NC];
-
-    outSizePadded.emplace_back(
-        rewriter.create<arith::ConstantIndexOp>(loc, size + pad));
-  }
-
-  // In case if input tensor size is not divisible by stride
-  // (e.g. pooling_input_size=5, kernel_size=2, stride=2, output_size=2)
-  // pad self and indices tensors to avoid out of bounds access.
-  SmallVector<int64_t> expectedInputShape =
-      llvm::to_vector(resType.getShape().drop_back(poolingDimensionality));
-  for (auto &&[str, pad, resSize] :
-       llvm::zip_equal(stride, padding, inferredOutSize))
-    expectedInputShape.emplace_back((resSize + str - 1) / str + pad * 2);
-
-  if (expectedInputShape != selfType.getShape()) {
-    // TODO: this is probably expensive, and it may be possible to solve by
-    // cleverly constructing affine maps for the next linalg.generic op,
-    // but I'm not smart enough to figure this out.
-
-    SmallVector<int64_t> low(outRank, 0);
-    SmallVector<int64_t> high(NC, 0);
-    for (auto &&[inpSize, outSize] : llvm::zip_equal(
-             inputSize,
-             ArrayRef(expectedInputShape).take_back(poolingDimensionality))) {
-      high.emplace_back(outSize - inpSize);
+  Type elementType = selfType.getElementType();
+
+  // Initialize output tensor with zeros
+  Value zero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
+  Value init = b.create<tensor::EmptyOp>(loc, resType.getShape(), elementType);
+  init = b.create<linalg::FillOp>(loc, zero, init)->getResult(0);
+
+  // Build affine expressions for input and output mappings
+  // For 2D: input maps (d0,d1,d2,d3,d4,d5) -> (d0,d1,d4,d5), output ->
+  // (d0,d1,d2,d3)
+  SmallVector<AffineExpr> inputExprs, outputExprs;
+  inputExprs.reserve(outRank);
+  outputExprs.reserve(outRank);
+  const int64_t totalDims = outRank + poolingDimensionality;
+
+  for (int64_t i = 0; i < totalDims; ++i) {
+    AffineExpr dim = b.getAffineDimExpr(i);
+    if (i < outRank) {
+      outputExprs.push_back(dim);
     }
-
-    // Pad the indices tensor with a value which cannot appear in real data
-    // (-1) so it will never match. In this case we can pad self with any
-    // value, as it will never affect the output.
-    Value zero = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getZeroAttr(selfType.getElementType()));
-    Value invalidIdx = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(indicesType.getElementType(), -1));
-    self =
-        torch_to_linalg::getPaddedTensor(op, rewriter, self, low, high, zero);
-    indices = torch_to_linalg::getPaddedTensor(op, rewriter, indices, low, high,
-                                               invalidIdx);
-  }
-
-  Value init = rewriter.create<tensor::EmptyOp>(
-      loc, getAsOpFoldResult(outSizePadded), selfType.getElementType());
-
-  SmallVector<AffineExpr> inputExprs;
-  SmallVector<AffineExpr> outputExprs;
-  for (auto i : llvm::seq<int64_t>(0, outRank)) {
-    AffineExpr dim = rewriter.getAffineDimExpr(i);
-    if (i < NC) {
-      inputExprs.emplace_back(dim);
-    } else {
-      int64_t j = i - NC;
-      inputExprs.emplace_back(dim.floorDiv(stride[j]));
+    if (i < NC || i >= outRank) {
+      inputExprs.push_back(dim);
     }
-    outputExprs.emplace_back(dim);
   }
 
   SmallVector<AffineMap> indexingMaps = AffineMap::inferFromExprList(
-      {inputExprs, inputExprs, outputExprs}, rewriter.getContext());
-
-  SmallVector<utils::IteratorType> iteratorTypes(outRank,
-                                                 utils::IteratorType::parallel);
-
-  auto computeIndex = [&](OpBuilder &b, Location loc) -> Value {
-    // Next linalg.generic uses identity mapping for the unpooled tensor,
-    // compute linear index for output element, which we will the compare with
-    // values which came from indices tensor.
-    Value ret;
-    for (auto i : llvm::seq<int64_t>(NC, outRank)) {
-      Value idx = b.create<linalg::IndexOp>(loc, i);
-      // If pool input was padded, adjust indices so they start at 0 in the
-      // non-padded area. Indices outside non-padded area will make no sense,
-      // but it doesnt matter as we will cut the padded area later by
-      // extract_slice.
-      int64_t pad = padding[i - NC];
-      if (pad != 0) {
-        Value padVal = b.create<arith::ConstantIndexOp>(loc, pad);
-        idx = b.create<arith::SubIOp>(loc, idx, padVal);
-      }
-
-      if (!ret) {
-        ret = idx;
+      {inputExprs, inputExprs, outputExprs}, b.getContext());
+
+  // Iterator types: parallel for output dims, reduction for pooling dims
+  SmallVector<utils::IteratorType> iteratorTypes;
+  iteratorTypes.reserve(totalDims);
+  iteratorTypes.append(outRank, utils::IteratorType::parallel);
+  iteratorTypes.append(poolingDimensionality, utils::IteratorType::reduction);
+
+  // Compute linear index from multi-dimensional coordinates (e.g., h*width + w)
+  auto computeLinearIndex = [&](OpBuilder &b, Location loc) -> Value {
+    Value linearIndex;
+    ArrayRef<int64_t> shape = resType.getShape();
+    for (int64_t i = NC; i < outRank; ++i) {
+      Value currentIdx = b.create<linalg::IndexOp>(loc, i);
+      if (!linearIndex) {
+        linearIndex = currentIdx;
       } else {
-        Value size =
-            b.create<arith::ConstantIndexOp>(loc, resType.getShape()[i]);
-        ret = b.create<arith::MulIOp>(loc, ret, size);
-        ret = b.create<arith::AddIOp>(loc, ret, idx);
+        Value dimSize = b.create<arith::ConstantIndexOp>(loc, shape[i]);
+        linearIndex = b.create<arith::MulIOp>(loc, linearIndex, dimSize);
+        linearIndex = b.create<arith::AddIOp>(loc, linearIndex, currentIdx);
       }
     }
-    return ret;
+    return linearIndex;
  };
 
-  auto builder = [&](OpBuilder &b, Location loc, ValueRange args) {
-    // Compute current output linear index and compare it with the value
-    // from indices arg.
-    Value input = args[0];
-    Value zero =
-        b.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(input.getType()));
-    Value index = b.create<arith::IndexCastOp>(loc, indexType, args[1]);
-    Value currentIndex = computeIndex(b, loc);
-    Value cmp = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, index,
-                                        currentIndex);
-    Value out = b.create<arith::SelectOp>(loc, cmp, input, zero);
-    b.create<linalg::YieldOp>(loc, out);
+  auto bodyBuilder = [&](OpBuilder &b, Location loc, ValueRange args) {
+    Value inputValue = args[0];
+    Value storedIndex = args[1];
+    Value currentOutput = args[2];
+
+    // Convert stored index to index type and compare with current position
+    Value indexAsIndex =
+        b.create<arith::IndexCastOp>(loc, indexType, storedIndex);
+    Value currentLinearIndex = computeLinearIndex(b, loc);
+    Value isMatch = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                            indexAsIndex, currentLinearIndex);
+
+    // Select input value if indices match, otherwise keep current output
+    Value result =
+        b.create<arith::SelectOp>(loc, isMatch, inputValue, currentOutput);
+    b.create<linalg::YieldOp>(loc, result);
   };
 
-  Value result =
-      rewriter
-          .create<linalg::GenericOp>(loc,
-                                     /*resultTensorTypes=*/init.getType(),
-                                     /*inputs=*/ValueRange({self, indices}),
-                                     /*outputs=*/init,
-                                     /*indexingMaps=*/indexingMaps,
-                                     /*iteratorTypes=*/iteratorTypes, builder)
-          .getResult(0);
-
-  if (llvm::any_of(padding, [](int64_t v) { return v != 0; })) {
-    // MaxPool input was padded, unpad it by taking the slice.
-    SmallVector<OpFoldResult> offsetVals(NC, rewriter.getI64IntegerAttr(0));
-    for (int64_t pad : padding)
-      offsetVals.emplace_back(rewriter.getI64IntegerAttr(pad));
-
-    SmallVector<OpFoldResult> sizeVals;
-    for (auto &&[i, dim] : llvm::enumerate(resType.getShape())) {
-      if (!ShapedType::isDynamic(dim)) {
-        sizeVals.emplace_back(rewriter.getI64IntegerAttr(dim));
-        continue;
-      }
-
-      sizeVals.emplace_back(rewriter.create<tensor::DimOp>(loc, self, i));
-    }
-    SmallVector<OpFoldResult> stridesVals(outRank,
-                                          rewriter.getI64IntegerAttr(1));
-    result = rewriter.create<tensor::ExtractSliceOp>(loc, result, offsetVals,
-                                                     sizeVals, stridesVals);
-  }
-
-  if (result.getType() != resType)
-    result = rewriter.create<tensor::CastOp>(loc, resType, result);
-
-  return result;
+  return b
+      .create<linalg::GenericOp>(loc, init.getType(), ValueRange{self, indices},
+                                 init, indexingMaps, iteratorTypes, bodyBuilder)
+      .getResult(0);
 }
 
 namespace {
@@ -843,9 +761,8 @@ class ConvertAtenMaxUnpool3dOp final
     }
 
     int64_t poolingDimensionality = 3;
-    Value result = createMaxUnpoolOp(
-        op, poolingDimensionality, rewriter, typeConverter, self, indices,
-        spatialInputShape, inferredOutSize, stride, padding, resType);
+    Value result = createMaxUnpoolOp(op, poolingDimensionality, rewriter, self,
+                                     indices, stride, resType);
 
     rewriter.replaceOp(op, result);
     return success();
@@ -905,9 +822,8 @@ class ConvertAtenMaxUnpool2dOp final
     SmallVector<int64_t> stride(poolingDimensionality, poolingDimensionality);
     SmallVector<int64_t> padding(poolingDimensionality, 0);
 
-    Value result = createMaxUnpoolOp(op, poolingDimensionality, rewriter,
-                                     typeConverter, self, indices, inputSize,
-                                     inferredOutSize, stride, padding, resType);
+    Value result = createMaxUnpoolOp(op, poolingDimensionality, rewriter, self,
+                                     indices, stride, resType);
 
     rewriter.replaceOp(op, result);
     return success();
diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir
index b95e96c4a461..69b45f29d9aa 100644
--- a/test/Conversion/TorchToLinalg/pooling.mlir
+++ b/test/Conversion/TorchToLinalg/pooling.mlir
@@ -48,29 +48,38 @@ func.func @forward_max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vt
 
 // -----
 
-// CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 2, d3 floordiv 2)>
-// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-// CHECK-LABEL: func @forward_max_unpool2d
+// CHECK: #[[$INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>
+// CHECK: #[[$OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @forward_max_unpool2d(
+// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[2,2,2,4],f32>,
+// CHECK-SAME: %[[INDICES:.*]]: !torch.vtensor<[2,2,2,4],si64>) -> !torch.vtensor<[2,2,4,8],f32> {
+// CHECK: %[[INDICES_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INDICES]] : !torch.vtensor<[2,2,2,4],si64> -> tensor<2x2x2x4xi64>
+// CHECK: %[[INPUT_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT]] : !torch.vtensor<[2,2,2,4],f32> -> tensor<2x2x2x4xf32>
+// CHECK: %[[C8:.*]] = torch.constant.int 8
+// CHECK: %[[C4:.*]] = torch.constant.int 4
+// CHECK: %[[OUTPUT_SIZE:.*]] = torch.prim.ListConstruct %[[C4]], %[[C8]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EMPTY_OUTPUT:.*]] = tensor.empty() : tensor<2x2x4x8xf32>
+// CHECK: %[[FILLED_OUTPUT:.*]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[EMPTY_OUTPUT]] : tensor<2x2x4x8xf32>) -> tensor<2x2x4x8xf32>
+// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$INPUT_MAP]], #[[$INPUT_MAP]], #[[$OUTPUT_MAP]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[INPUT_TENSOR]], %[[INDICES_TENSOR]] : tensor<2x2x2x4xf32>, tensor<2x2x2x4xi64>) outs(%[[FILLED_OUTPUT]] : tensor<2x2x4x8xf32>) {
+// CHECK: ^bb0(%[[INPUT_VAL:.*]]: f32, %[[INDEX_VAL:.*]]: i64, %[[OUTPUT_VAL:.*]]: f32):
+// CHECK: %[[INDEX_CAST:.*]] = arith.index_cast %[[INDEX_VAL]] : i64 to index
+// CHECK: %[[ROW_IDX:.*]] = linalg.index 2 : index
+// CHECK: %[[COL_IDX:.*]] = linalg.index 3 : index
+// CHECK: %[[WIDTH:.*]] = arith.constant 8 : index
+// CHECK: %[[LINEAR_IDX:.*]] = arith.muli %[[ROW_IDX]], %[[WIDTH]] : index
+// CHECK: %[[FLAT_IDX:.*]] = arith.addi %[[LINEAR_IDX]], %[[COL_IDX]] : index
+// CHECK: %[[IS_MATCH:.*]] = arith.cmpi eq, %[[INDEX_CAST]], %[[FLAT_IDX]] : index
+// CHECK: %[[SELECTED:.*]] = arith.select %[[IS_MATCH]], %[[INPUT_VAL]], %[[OUTPUT_VAL]] : f32
+// CHECK: linalg.yield %[[SELECTED]] : f32
+// CHECK: } -> tensor<2x2x4x8xf32>
+// CHECK: %[[OUTPUT_TENSOR:.*]] = torch_c.from_builtin_tensor %[[RESULT]] : tensor<2x2x4x8xf32> -> !torch.vtensor<[2,2,4,8],f32>
+// CHECK: return %[[OUTPUT_TENSOR]] : !torch.vtensor<[2,2,4,8],f32>
+// CHECK: }
 func.func @forward_max_unpool2d(%arg0: !torch.vtensor<[2,2,2,4],f32>, %arg1: !torch.vtensor<[2,2,2,4],si64>) -> !torch.vtensor<[2,2,4,8],f32> {
   %int8 = torch.constant.int 8
   %int4 = torch.constant.int 4
   %0 = torch.prim.ListConstruct %int4, %int8 : (!torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: = linalg.generic
-  // CHECK-SAME: indexing_maps = [#map, #map, #map1]
-  // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
-  // CHECK: ins(
-  // CHECK: outs(
-  // CHECK: ^bb0(
-  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-  // CHECK: %[[CAST:.*]] = arith.index_cast %{{.*}} : i64 to index
-  // CHECK: %[[IDX2:.*]] = linalg.index 2 : index
-  // CHECK: %[[IDX3:.*]] = linalg.index 3 : index
-  // CHECK: %[[C8_2:.*]] = arith.constant 8 : index
-  // CHECK: %[[MUL:.*]] = arith.muli %[[IDX2]], %[[C8_2]] : index
-  // CHECK: %[[ADD:.*]] = arith.addi %[[MUL]], %[[IDX3]] : index
-  // CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[CAST]], %[[ADD]] : index
-  // CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %{{.*}}, %[[CST]] : f32
-  // CHECK: linalg.yield %[[SEL]] : f32
   %1 = torch.aten.max_unpool2d %arg0, %arg1, %0 : !torch.vtensor<[2,2,2,4],f32>, !torch.vtensor<[2,2,2,4],si64>, !torch.list<int> -> !torch.vtensor<[2,2,4,8],f32>
   return %1 : !torch.vtensor<[2,2,4,8],f32>
 }