Skip to content

Commit 7cdb3ce

Browse files
rsuderman authored and memfrob committed
[mlir][tosa] Add tosa.reduce_any and tosa.reduce_all linalg lowering
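Added lowerings for TOSA's boolean reduce operations (tosa.reduce_all and tosa.reduce_any). This includes a fix to maintain the output rank of reduce operations: the reduced axis is kept as a dimension of size 1 (e.g. reducing tensor<5x4xf32> along axis 0 yields tensor<1x4xf32>).

Reviewed By: silvas

Differential Revision: https://reviews.llvm.org/D99228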
1 parent d19ca62 commit 7cdb3ce

File tree

2 files changed

+64
-23
lines changed

2 files changed

+64
-23
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -514,6 +514,12 @@ static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy,
514514
return rewriter.getIntegerAttr(
515515
elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
516516

517+
if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
518+
return rewriter.getIntegerAttr(elementTy, APInt::getAllOnesValue(1));
519+
520+
if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
521+
return rewriter.getIntegerAttr(elementTy, APInt::getNullValue(1));
522+
517523
if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<FloatType>())
518524
return rewriter.getFloatAttr(
519525
elementTy, APFloat::getLargest(
@@ -573,6 +579,12 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op,
573579
return rewriter.create<mlir::SelectOp>(loc, predicate, args[0], args[1]);
574580
}
575581

582+
if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
583+
return rewriter.create<mlir::AndOp>(loc, args);
584+
585+
if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
586+
return rewriter.create<mlir::OrOp>(loc, args);
587+
576588
return {};
577589
}
578590

@@ -613,6 +625,8 @@ static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
613625
: getParallelIteratorTypeName());
614626
if (axis != i)
615627
dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
628+
else
629+
dstExprs.push_back(rewriter.getAffineConstantExpr(0));
616630
}
617631

618632
bool didEncounterError = false;
@@ -1419,7 +1433,8 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
14191433
PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
14201434
PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>,
14211435
PointwiseConverter<tosa::SigmoidOp>, IdentityNConverter<tosa::IdentityOp>,
1422-
IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
1436+
IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceAllOp>,
1437+
ReduceConverter<tosa::ReduceAnyOp>, ReduceConverter<tosa::ReduceMinOp>,
14231438
ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
14241439
ReduceConverter<tosa::ReduceProdOp>, ArgMaxConverter, ConcatConverter,
14251440
PadConverter, ReshapeConverter, RescaleConverter, ReverseConverter,

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 48 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -442,98 +442,124 @@ func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
442442
// -----
443443

444444
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
445-
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
446-
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
445+
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, d1)>
446+
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, 0)>
447447

448448
// CHECK-LABEL: @reduce_float
449449
// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
450450
func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
451-
// CHECK: [[INIT:%.+]] = linalg.init_tensor [4]
451+
// CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 4]
452452
// CHECK: [[CST0:%.+]] = constant 0.0
453453
// CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
454-
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<4xf32>)
454+
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<1x4xf32>)
455455
// CHECK: ^bb0(%arg1: f32, %arg2: f32)
456456
// CHECK: [[RES:%.+]] = addf %arg1, %arg2 : f32
457457
// CHECK: linalg.yield [[RES]] : f32
458-
%0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<4xf32>
458+
%0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
459459

460-
// CHECK: [[INIT:%.+]] = linalg.init_tensor [5]
460+
// CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 1]
461461
// CHECK: [[CST0:%.+]] = constant 0.0
462462
// CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
463-
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<5xf32>)
463+
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xf32>) outs([[FILL]] : tensor<5x1xf32>)
464464
// CHECK: ^bb0(%arg1: f32, %arg2: f32)
465465
// CHECK: [[RES:%.+]] = addf %arg1, %arg2 : f32
466466
// CHECK: linalg.yield [[RES]] : f32
467-
%1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xf32>) -> tensor<5xf32>
467+
%1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xf32>) -> tensor<5x1xf32>
468468

469469
// CHECK: constant 1.0
470470
// CHECK: linalg.fill
471471
// CHECK: linalg.generic
472472
// CHECK: mulf
473-
%2 = "tosa.reduce_prod"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<4xf32>
473+
%2 = "tosa.reduce_prod"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
474474

475475
// CHECK: constant 3.40282347E+38 : f32
476476
// CHECK: linalg.fill
477477
// CHECK: linalg.generic
478478
// CHECK: cmpf olt
479479
// CHECK: select
480-
%3 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<4xf32>
480+
%3 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
481481

482482
// CHECK: constant -3.40282347E+38 : f32
483483
// CHECK: linalg.fill
484484
// CHECK: linalg.generic
485485
// CHECK: cmpf ogt
486486
// CHECK: select
487-
%4 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<4xf32>
487+
%4 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<5x4xf32>) -> tensor<1x4xf32>
488488
return
489489
}
490490

491491
// -----
492492

493493
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
494-
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
495-
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0)>
494+
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, d1)>
495+
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, 0)>
496496

497497
// CHECK-LABEL: @reduce_int
498498
// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xi32>
499499
func @reduce_int(%arg0: tensor<5x4xi32>) -> () {
500-
// CHECK: [[INIT:%.+]] = linalg.init_tensor [4]
500+
// CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 4]
501501
// CHECK: [[CST0:%.+]] = constant 0
502502
// CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
503-
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<4xi32>)
503+
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<1x4xi32>)
504504
// CHECK: ^bb0(%arg1: i32, %arg2: i32)
505505
// CHECK: [[RES:%.+]] = addi %arg1, %arg2 : i32
506506
// CHECK: linalg.yield [[RES]] : i32
507-
%0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<4xi32>
507+
%0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<1x4xi32>
508508

509-
// CHECK: [[INIT:%.+]] = linalg.init_tensor [5]
509+
// CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 1]
510510
// CHECK: [[CST0:%.+]] = constant 0
511511
// CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
512-
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<5xi32>)
512+
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP2]]], iterator_types = ["parallel", "reduction"]} ins([[ARG0]] : tensor<5x4xi32>) outs([[FILL]] : tensor<5x1xi32>)
513513
// CHECK: ^bb0(%arg1: i32, %arg2: i32)
514514
// CHECK: [[RES:%.+]] = addi %arg1, %arg2 : i32
515515
// CHECK: linalg.yield [[RES]] : i32
516-
%1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5xi32>
516+
%1 = "tosa.reduce_sum"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x1xi32>
517517

518518
// CHECK: constant 1
519519
// CHECK: linalg.fill
520520
// CHECK: linalg.generic
521521
// CHECK: muli
522-
%2 = "tosa.reduce_prod"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<4xi32>
522+
%2 = "tosa.reduce_prod"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<1x4xi32>
523523

524524
// CHECK: constant 2147483647 : i32
525525
// CHECK: linalg.fill
526526
// CHECK: linalg.generic
527527
// CHECK: cmpi slt
528528
// CHECK: select
529-
%3 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<4xi32>
529+
%3 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<1x4xi32>
530530

531531
// CHECK: constant -2147483648 : i32
532532
// CHECK: linalg.fill
533533
// CHECK: linalg.generic
534534
// CHECK: cmpi sgt
535535
// CHECK: select
536-
%4 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<4xi32>
536+
%4 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<1x4xi32>
537+
return
538+
}
539+
540+
// -----
541+
542+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
543+
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, d1)>
544+
545+
// CHECK-LABEL: @reduce_bool
546+
// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xi1>
547+
func @reduce_bool(%arg0: tensor<5x4xi1>) -> () {
548+
// CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 4]
549+
// CHECK: [[CST0:%.+]] = constant true
550+
// CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST0]])
551+
// CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins([[ARG0]] : tensor<5x4xi1>) outs([[FILL]] : tensor<1x4xi1>)
552+
// CHECK: ^bb0(%arg1: i1, %arg2: i1)
553+
// CHECK: [[RES:%.+]] = and %arg1, %arg2 : i1
554+
// CHECK: linalg.yield [[RES]] : i1
555+
%0 = "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<5x4xi1>) -> tensor<1x4xi1>
556+
557+
// CHECK: constant false
558+
// CHECK: linalg.fill
559+
// CHECK: linalg.generic
560+
// CHECK: or
561+
%1 = "tosa.reduce_any"(%arg0) {axis = 0 : i64} : (tensor<5x4xi1>) -> tensor<1x4xi1>
562+
537563
return
538564
}
539565

0 commit comments

Comments
 (0)