Skip to content

Commit b933507

Browse files
committed
fix short form print and update remaining tests
1 parent 5a1622d commit b933507

File tree

6 files changed

+27
-21
lines changed

6 files changed

+27
-21
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1573,7 +1573,11 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
15731573
return success();
15741574
}
15751575

1576-
static bool canUseShortForm(Block *body, bool initFirst = false) {
1576+
static bool canUseShortForm(Block *body, bool initFirst = false,
1577+
bool mapInit = true) {
1578+
// `intFirst == true` implies that we want to map init arg
1579+
if (initFirst && !mapInit)
1580+
return false;
15771581
// Check if the body can be printed in short form. The following 4 conditions
15781582
// must be satisfied:
15791583

@@ -1585,7 +1589,7 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
15851589
// 2) The payload op must have the same number of operands as the number of
15861590
// block arguments.
15871591
if (payload.getNumOperands() == 0 ||
1588-
payload.getNumOperands() != body->getNumArguments())
1592+
payload.getNumOperands() != body->getNumArguments() - int(!mapInit))
15891593
return false;
15901594

15911595
// 3) If `initFirst` is true (e.g., for reduction ops), the init block
@@ -1603,7 +1607,8 @@ static bool canUseShortForm(Block *body, bool initFirst = false) {
16031607
}
16041608
} else {
16051609
for (const auto &[operand, bbArg] :
1606-
llvm::zip(payload.getOperands(), body->getArguments())) {
1610+
llvm::zip(payload.getOperands(),
1611+
body->getArguments().drop_back(int(!mapInit)))) {
16071612
if (bbArg != operand)
16081613
return false;
16091614
}
@@ -1635,7 +1640,8 @@ static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
16351640

16361641
void MapOp::print(OpAsmPrinter &p) {
16371642
Block *mapper = getBody();
1638-
bool useShortForm = canUseShortForm(mapper);
1643+
bool useShortForm =
1644+
canUseShortForm(mapper, /*initFirst=*/false, /*mapInit*/ false);
16391645
if (useShortForm) {
16401646
printShortForm(p, &mapper->getOperations().front());
16411647
}

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1423,7 +1423,7 @@ func.func @transpose_buffer(%input: memref<?xf32>,
14231423
func.func @recursive_effect(%arg : tensor<1xf32>) {
14241424
%init = arith.constant dense<0.0> : tensor<1xf32>
14251425
%mapped = linalg.map ins(%arg:tensor<1xf32>) outs(%init :tensor<1xf32>)
1426-
(%in : f32) {
1426+
(%in : f32, %out: f32) {
14271427
vector.print %in : f32
14281428
linalg.yield %in : f32
14291429
}

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ func.func @map_binary_wrong_yield_operands(
681681
%add = linalg.map
682682
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
683683
outs(%init:tensor<64xf32>)
684-
(%lhs_elem: f32, %rhs_elem: f32) {
684+
(%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
685685
%0 = arith.addf %lhs_elem, %rhs_elem: f32
686686
// expected-error @+1{{'linalg.yield' op expected number of yield values (2) to match the number of inits / outs operands of the enclosing LinalgOp (1)}}
687687
linalg.yield %0, %0: f32, f32
@@ -694,11 +694,11 @@ func.func @map_binary_wrong_yield_operands(
694694
func.func @map_input_mapper_arity_mismatch(
695695
%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
696696
-> tensor<64xf32> {
697-
// expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 2 and 3}}
697+
// expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 3 and 4}}
698698
%add = linalg.map
699699
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
700700
outs(%init:tensor<64xf32>)
701-
(%lhs_elem: f32, %rhs_elem: f32, %extra_elem: f32) {
701+
(%lhs_elem: f32, %rhs_elem: f32, %out: f32, %extra_elem: f32) {
702702
%0 = arith.addf %lhs_elem, %rhs_elem: f32
703703
linalg.yield %0: f32
704704
}
@@ -714,7 +714,7 @@ func.func @map_input_mapper_type_mismatch(
714714
%add = linalg.map
715715
ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
716716
outs(%init:tensor<64xf32>)
717-
(%lhs_elem: f64, %rhs_elem: f64) {
717+
(%lhs_elem: f64, %rhs_elem: f64, %out: f32) {
718718
%0 = arith.addf %lhs_elem, %rhs_elem: f64
719719
linalg.yield %0: f64
720720
}
@@ -730,7 +730,7 @@ func.func @map_input_output_shape_mismatch(
730730
%add = linalg.map
731731
ins(%lhs, %rhs : tensor<64x64xf32>, tensor<64x64xf32>)
732732
outs(%init:tensor<32xf32>)
733-
(%lhs_elem: f32, %rhs_elem: f32) {
733+
(%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
734734
%0 = arith.addf %lhs_elem, %rhs_elem: f32
735735
linalg.yield %0: f32
736736
}

mlir/test/Dialect/Linalg/one-shot-bufferize.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
339339
%add = linalg.map
340340
ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
341341
outs(%init:tensor<64xf32>)
342-
(%lhs_elem: f32, %rhs_elem: f32) {
342+
(%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
343343
%0 = arith.addf %lhs_elem, %rhs_elem: f32
344344
linalg.yield %0: f32
345345
}

mlir/test/Dialect/Linalg/roundtrip.mlir

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -341,15 +341,15 @@ func.func @mixed_parallel_reduced_results(%arg0 : tensor<?x?x?xf32>,
341341
func.func @map_no_inputs(%init: tensor<64xf32>) -> tensor<64xf32> {
342342
%add = linalg.map
343343
outs(%init:tensor<64xf32>)
344-
() {
344+
(%out: f32) {
345345
%0 = arith.constant 0.0: f32
346346
linalg.yield %0: f32
347347
}
348348
func.return %add : tensor<64xf32>
349349
}
350350
// CHECK-LABEL: func @map_no_inputs
351351
// CHECK: linalg.map outs
352-
// CHECK-NEXT: () {
352+
// CHECK-NEXT: (%[[OUT:.*]]: f32) {
353353
// CHECK-NEXT: arith.constant
354354
// CHECK-NEXT: linalg.yield
355355
// CHECK-NEXT: }
@@ -361,7 +361,7 @@ func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
361361
%add = linalg.map
362362
ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
363363
outs(%init:tensor<64xf32>)
364-
(%lhs_elem: f32, %rhs_elem: f32) {
364+
(%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
365365
%0 = arith.addf %lhs_elem, %rhs_elem: f32
366366
linalg.yield %0: f32
367367
}
@@ -378,7 +378,7 @@ func.func @map_binary_memref(%lhs: memref<64xf32>, %rhs: memref<64xf32>,
378378
linalg.map
379379
ins(%lhs, %rhs: memref<64xf32>, memref<64xf32>)
380380
outs(%init:memref<64xf32>)
381-
(%lhs_elem: f32, %rhs_elem: f32) {
381+
(%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
382382
%0 = arith.addf %lhs_elem, %rhs_elem: f32
383383
linalg.yield %0: f32
384384
}
@@ -393,7 +393,7 @@ func.func @map_unary(%input: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64
393393
%abs = linalg.map
394394
ins(%input:tensor<64xf32>)
395395
outs(%init:tensor<64xf32>)
396-
(%input_elem: f32) {
396+
(%input_elem: f32, %out: f32) {
397397
%0 = math.absf %input_elem: f32
398398
linalg.yield %0: f32
399399
}
@@ -408,7 +408,7 @@ func.func @map_unary_memref(%input: memref<64xf32>, %init: memref<64xf32>) {
408408
linalg.map
409409
ins(%input:memref<64xf32>)
410410
outs(%init:memref<64xf32>)
411-
(%input_elem: f32) {
411+
(%input_elem: f32, %out: f32) {
412412
%0 = math.absf %input_elem: f32
413413
linalg.yield %0: f32
414414
}
@@ -604,7 +604,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
604604
%add = linalg.map
605605
ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
606606
outs(%init:tensor<64xf32>)
607-
(%lhs_elem: f32, %rhs_elem: f32) {
607+
(%lhs_elem: f32, %rhs_elem: f32, %out: f32) {
608608
%0 = arith.addf %lhs_elem, %rhs_elem fastmath<fast> : f32
609609
linalg.yield %0: f32
610610
}
@@ -622,7 +622,7 @@ func.func @map_arith_with_attr(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
622622

623623
func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x32xf32>, %init: tensor<1x32xf32>) -> tensor<1x32xf32> {
624624
%mapped = linalg.map ins(%lhs, %rhs : tensor<1x32xf32>, tensor<1x32xf32>) outs(%init : tensor<1x32xf32>)
625-
(%in_1: f32, %in_2: f32) {
625+
(%in_1: f32, %in_2: f32, %out: f32) {
626626
%1 = arith.maximumf %in_1, %in_2 : f32
627627
linalg.yield %in_1 : f32
628628
}
@@ -634,7 +634,7 @@ func.func @map_not_short_form_compatible(%lhs: tensor<1x32xf32>, %rhs: tensor<1x
634634
// CHECK-NOT: linalg.map { arith.maximumf } ins(%[[LHS]] : tensor<1x32xf32>
635635
// CHECK: linalg.map ins(%[[LHS]], %[[RHS]] : tensor<1x32xf32>, tensor<1x32xf32>)
636636
// CHECK-SAME: outs(%[[INIT]] : tensor<1x32xf32>)
637-
// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32) {
637+
// CHECK-NEXT: (%[[IN1:.*]]: f32, %[[IN2:.*]]: f32, %[[OUT:.*]]: f32) {
638638
// CHECK-NEXT: %[[MAX_RESULT:.*]] = arith.maximumf %[[IN1]], %[[IN2]] : f32
639639
// CHECK-NEXT: linalg.yield %[[IN1]] : f32
640640
// CHECK-NEXT: }

mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ func.func @vectorize_map(%arg0: memref<64xf32>,
381381
%arg1: memref<64xf32>, %arg2: memref<64xf32>) {
382382
linalg.map ins(%arg0, %arg1 : memref<64xf32>, memref<64xf32>)
383383
outs(%arg2 : memref<64xf32>)
384-
(%in: f32, %in_0: f32) {
384+
(%in: f32, %in_0: f32, %out: f32) {
385385
%0 = arith.addf %in, %in_0 : f32
386386
linalg.yield %0 : f32
387387
}

0 commit comments

Comments
 (0)