@@ -79,4 +79,37 @@ func.func @map_ops(%in1: tensor<8xf32>, %in2: tensor<8xf32>) -> tensor<8xf32> {
7979//  CHECK-NEXT:     %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]] 
8080//  CHECK-NEXT:     %[[SQRT:.*]] = math.sqrt %[[ADD]] 
8181//  CHECK-NEXT:     linalg.yield %[[SQRT]]  
82- //   CHECK-NOT:   linalg.generic 
82+ //   CHECK-NOT:   linalg.map 
83+ 
84+ // ----- 
85+ 
86+ func.func  @map_ops_mixed_types (%arg0:  tensor <8 xf32 >, %arg1:  tensor <8 xf32 >) -> tensor <8 xf32 > {
87+   %init  = tensor.empty () : tensor <8 xi1 >
88+   %initf  = tensor.empty () : tensor <8 xf32 >
89+   %0  = linalg.map  {math.sqrt } ins (%arg0  : tensor <8 xf32 >) outs (%initf  : tensor <8 xf32 >)
90+   %1  = linalg.map  {math.exp } ins (%arg1  : tensor <8 xf32 >) outs (%initf  : tensor <8 xf32 >)
91+   %2  = linalg.map  ins (%0 , %1  : tensor <8 xf32 >, tensor <8 xf32 >) outs  (%init  : tensor <8 xi1 >)
92+     (%in0  : f32 , %in1  : f32 ) {
93+       %cmp  = arith.cmpf  olt , %in0 , %in1  : f32 
94+       linalg.yield  %cmp  : i1 
95+   }
96+   %3  = linalg.map  { arith.select  } ins (%2 , %0 , %1  : tensor <8 xi1 >, tensor <8 xf32 >, tensor <8 xf32 >) outs (%initf  : tensor <8 xf32 >) 
97+   return  %3  : tensor <8 xf32 >
98+ }
99+ 
100+ // CHECK-LABEL: func @map_ops_mixed_types 
101+ //  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32> 
102+ //  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32> 
103+ //       CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32> 
104+ //       CHECK:   %[[FUSED_OP:.+]] = linalg.generic 
105+ //  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] : {{.*}}) outs(%[[EMPTY]] : 
106+ //  CHECK-NEXT:   ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32): 
107+ //  CHECK-NEXT:     %[[EXP0:.*]] = math.exp %[[IN1]] 
108+ //  CHECK-NEXT:     %[[SQRT0:.*]] = math.sqrt %[[IN0]] 
109+ //  CHECK-NEXT:     %[[EXP1:.*]] = math.exp %[[IN1]] 
110+ //  CHECK-NEXT:     %[[SQRT1:.*]] = math.sqrt %[[IN0]] 
111+ //  CHECK-NEXT:     %[[CMP:.*]] = arith.cmpf olt, %[[SQRT1]], %[[EXP1]] 
112+ //  CHECK-NEXT:     %[[RES:.*]] = arith.select %[[CMP]], %[[SQRT0]], %[[EXP0]] 
113+ //  CHECK-NEXT:     linalg.yield %[[RES]]  
114+ //   CHECK-NOT:   linalg.map 
115+ 
0 commit comments