Skip to content

Commit 2e1e0f1

Browse files
committed
add uitofp rewrite
1 parent 3797daa commit 2e1e0f1

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,8 +1452,10 @@ void vector::populateVectorNarrowTypeRewritePatterns(
14521452
RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
14531453
RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
14541454
benefit.getBenefit() + 1);
1455-
patterns.add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>>(
1456-
patterns.getContext(), benefit.getBenefit() + 1);
1455+
patterns
1456+
.add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
1457+
RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
1458+
patterns.getContext(), benefit.getBenefit() + 1);
14571459
}
14581460

14591461
void vector::populateVectorTransposeNarrowTypeRewritePatterns(

mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,34 @@ func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
262262
return %0 : vector<8x32xf32>
263263
}
264264

265+
// CHECK-LABEL: func.func @aligned_uitofp(
266+
func.func @aligned_uitofp(%a: vector<8xi4>) -> vector<8xf32> {
267+
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
268+
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
269+
// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
270+
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
271+
// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
272+
// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
273+
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
274+
// CHECK: %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
275+
%0 = arith.uitofp %a : vector<8xi4> to vector<8xf32>
276+
return %0 : vector<8xf32>
277+
}
278+
279+
// CHECK-LABEL: func.func @aligned_uitofp_2d(
280+
func.func @aligned_uitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
281+
// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
282+
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
283+
// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
284+
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
285+
// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
286+
// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
287+
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
288+
// CHECK: %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
289+
%0 = arith.uitofp %a : vector<8x32xi4> to vector<8x32xf32>
290+
return %0 : vector<8x32xf32>
291+
}
292+
265293
// CHECK-LABEL: func.func @aligned_trunci(
266294
func.func @aligned_trunci(%a: vector<8xi32>) -> vector<8xi4> {
267295
// CHECK-SAME: %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> {
@@ -314,7 +342,7 @@ func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
314342
// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[I4_MASK]] : vector<3x8x16xi8>
315343
// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[LEFT_SHIFT_BITS]] : vector<3x8x16xi8>
316344
// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<3x8x16xi8>
317-
// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<3x8x16xi8> to vector<3x8x32xi4>
345+
// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<3x8x16xi8> to vector<3x8x32xi4>
318346
%0 = arith.trunci %a : vector<3x8x32xi32> to vector<3x8x32xi4>
319347
return %0 : vector<3x8x32xi4>
320348
}

0 commit comments

Comments
 (0)