@@ -188,30 +188,24 @@ func.func @vecaddf_bf16(%arg0: vector<16xbf16>, %arg1: vector<16xbf16>) -> vecto
188188 return %0 : vector <16 xbf16 >
189189}
190190
191+ // Use llvm-aie to lower arith.subf on vectors.
// NOTE(review): this hunk removes the expected aievec.cast / aievec.sub_elem /
// aievec.cast sequence for vector<16xf32> and instead checks that the
// arith.subf op is still present on the original operands and that its result
// is what gets returned. The added comment just above attributes the actual
// lowering to llvm-aie -- presumably a later stage; confirm against the pass
// pipeline in this test's RUN line, which is not visible in this hunk.
191192// CHECK-LABEL: func @vecsubf_f32(
192193// CHECK-SAME: %[[LHS:.*]]: vector<16xf32>,
193194// CHECK-SAME: %[[RHS:.*]]: vector<16xf32>)
194195func.func @vecsubf_f32 (%arg0: vector <16 xf32 >, %arg1: vector <16 xf32 >) -> vector <16 xf32 > {
195- // CHECK: %[[LCAST:.*]] = aievec.cast %[[LHS]] {isResAcc = true} : vector<16xf32>, vector<16xf32>
196- // CHECK: %[[RCAST:.*]] = aievec.cast %[[RHS]] {isResAcc = true} : vector<16xf32>, vector<16xf32>
197- // CHECK: %[[SUB:.*]] = aievec.sub_elem %[[LCAST]], %[[RCAST:.*]] : vector<16xf32>
198- // CHECK: %[[CAST:.*]] = aievec.cast %[[SUB]] {isResAcc = false} : vector<16xf32>, vector<16xf32>
196+ // CHECK: %[[SUB:.*]] = arith.subf %[[LHS]], %[[RHS]] : vector<16xf32>
199197 %0 = arith.subf %arg0 , %arg1 : vector <16 xf32 >
200- // CHECK: return %[[CAST ]] : vector<16xf32>
198+ // CHECK: return %[[SUB ]] : vector<16xf32>
201199 return %0 : vector <16 xf32 >
202200}
203201
// NOTE(review): this hunk removes the expected bf16 sequence (arith.constant 0,
// aievec.ups on both operands to vector<16xf32>, aievec.sub_elem, then
// aievec.srs back to vector<16xbf16>) and instead checks that arith.subf on
// the vector<16xbf16> operands is still present and that its result is
// returned directly. No comment in this test states who lowers the op;
// presumably the same llvm-aie path as the f32 case -- TODO confirm.
204202// CHECK-LABEL: func @vecsubf_bf16(
205203// CHECK-SAME: %[[LHS:.*]]: vector<16xbf16>,
206204// CHECK-SAME: %[[RHS:.*]]: vector<16xbf16>)
207205func.func @vecsubf_bf16 (%arg0: vector <16 xbf16 >, %arg1: vector <16 xbf16 >) -> vector <16 xbf16 > {
208- // CHECK: %[[C0:.*]] = arith.constant 0 : i32
209- // CHECK: %[[LUPS:.*]] = aievec.ups %[[LHS]] {shift = 0 : i8} : vector<16xbf16>, vector<16xf32>
210- // CHECK: %[[RUPS:.*]] = aievec.ups %[[RHS]] {shift = 0 : i8} : vector<16xbf16>, vector<16xf32>
211- // CHECK: %[[SUB:.*]] = aievec.sub_elem %[[LUPS]], %[[RUPS]] : vector<16xf32>
212- // CHECK: %[[SRS:.*]] = aievec.srs %[[SUB]], %[[C0]] : vector<16xf32>, i32, vector<16xbf16>
206+ // CHECK: %[[SUB:.*]] = arith.subf %[[LHS]], %[[RHS]] : vector<16xbf16>
213207 %0 = arith.subf %arg0 , %arg1 : vector <16 xbf16 >
214- // CHECK: return %[[SRS ]] : vector<16xbf16>
208+ // CHECK: return %[[SUB ]] : vector<16xbf16>
215209 return %0 : vector <16 xbf16 >
216210}
217211
0 commit comments