Commit 4c3b0a6

[mlir][tosa] Fix Map for Bias Broadcast (#89059)
1 parent effb2f1 commit 4c3b0a6
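
In short: when lowering a TOSA convolution, the bias operand is broadcast into the result tensor through a linalg indexing map. Previously that map was always built from dimension expressions; for a rank-1 bias holding a single element (e.g. tensor<1xf32>), this commit instead emits a constant map so the one value is broadcast to every output element. A minimal sketch of the two kinds of maps, taken from the affine maps checked in the test added below (the #bias_* alias names are made up for illustration):

  // Per-channel bias (e.g. tensor<28xf32>): index the output channel dimension.
  #bias_per_channel = affine_map<(d0, d1, d2, d3) -> (d3)>
  // Single-element bias (tensor<1xf32>): always read element 0, i.e. broadcast it.
  #bias_scalar = affine_map<(d0, d1, d2, d3) -> (0)>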

File tree

2 files changed: +25 -3 lines

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 12 additions & 3 deletions
@@ -101,9 +101,18 @@ static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
   // The source tensor is broadcast to all the outer dimensions of the
   // result tensor.
   SmallVector<AffineExpr> sourceDims;
-  for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
-    auto expr = rewriter.getAffineDimExpr(dim + resultRank - sourceRank);
-    sourceDims.push_back(expr);
+  // In the case of a rank one source tensor with a single element TOSA
+  // specifies that the value be broadcast meaning we need an edge case for a
+  // constant map.
+  assert(sourceTy.hasStaticShape() &&
+         "Dynamic broadcasting shapes not supported!");
+  if (sourceRank == 1 && sourceTy.getDimSize(0) == 1) {
+    sourceDims.push_back(rewriter.getAffineConstantExpr(0));
+  } else {
+    for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
+      auto expr = rewriter.getAffineDimExpr(dim + resultRank - sourceRank);
+      sourceDims.push_back(expr);
+    }
   }

   // Creating maps for the input and output of the broacast-like generic op.
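
For illustration, with the new constant-map path a tosa.conv2d whose bias is a tensor<1xf32> should lower to a broadcast-like linalg.generic roughly as sketched below. The indexing maps and tensor shapes come from the CHECK lines in the test added in this commit; the region body and the SSA names (%bias, %init, %bcast) are assumptions added for readability, not part of the diff.

  // Sketch only: broadcast a single-element bias into the conv2d accumulator shape.
  %init = tensor.empty() : tensor<1x45x40x28xf32>
  %bcast = linalg.generic
      {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>,
                        affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%bias : tensor<1xf32>) outs(%init : tensor<1x45x40x28xf32>) {
  ^bb0(%in: f32, %out: f32):
    // Assumed body: forward the bias value to every element of the output.
    linalg.yield %in : f32
  } -> tensor<1x45x40x28xf32>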

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 13 additions & 0 deletions
@@ -503,6 +503,19 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)

 // -----

+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: @conv2d_scalar_bias_f32
+func.func @conv2d_scalar_bias_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<1xf32>) -> () {
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>) {
+  %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32>
+  return
+}
+
+// -----
+
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
