From 2e1e0f1350c6051e3eb09260546230305f8b87fc Mon Sep 17 00:00:00 2001 From: Thomas Ziereis Date: Fri, 8 Nov 2024 14:54:19 +0100 Subject: [PATCH 1/2] add uitofp rewrite --- .../Transforms/VectorEmulateNarrowType.cpp | 6 ++-- .../Vector/vector-rewrite-narrow-types.mlir | 30 ++++++++++++++++++- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 58841f29698e0..76ddaa2df5a9d 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -1452,8 +1452,10 @@ void vector::populateVectorNarrowTypeRewritePatterns( RewriteAlignedSubByteIntExt, RewriteAlignedSubByteIntTrunc>(patterns.getContext(), benefit.getBenefit() + 1); - patterns.add>( - patterns.getContext(), benefit.getBenefit() + 1); + patterns + .add, + RewriteAlignedSubByteIntExt>( + patterns.getContext(), benefit.getBenefit() + 1); } void vector::populateVectorTransposeNarrowTypeRewritePatterns( diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir index 84aaa9c61200b..75e46a79600d0 100644 --- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir +++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir @@ -262,6 +262,34 @@ func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> { return %0 : vector<8x32xf32> } +// CHECK-LABEL: func.func @aligned_uitofp( +func.func @aligned_uitofp(%a: vector<8xi4>) -> vector<8xf32> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> { +// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> +// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> +// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8> +// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8> +// CHECK: %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32> + %0 = arith.uitofp %a : vector<8xi4> to vector<8xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: func.func @aligned_uitofp_2d( +func.func @aligned_uitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> { +// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> { +// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8> +// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8> +// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8> +// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8> +// CHECK: %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32> + %0 = arith.uitofp %a : vector<8x32xi4> to vector<8x32xf32> + return %0 : vector<8x32xf32> +} + // CHECK-LABEL: func.func @aligned_trunci( func.func @aligned_trunci(%a: vector<8xi32>) -> vector<8xi4> { // CHECK-SAME: %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> { @@ -314,7 +342,7 @@ func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> { // CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[I4_MASK]] : vector<3x8x16xi8> // CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[LEFT_SHIFT_BITS]] : vector<3x8x16xi8> // CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<3x8x16xi8> - // CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<3x8x16xi8> to vector<3x8x32xi4> + // CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<3x8x16xi8> to vector<3x8x32xi4> %0 = arith.trunci %a : vector<3x8x32xi32> to vector<3x8x32xi4> return %0 : vector<3x8x32xi4> } From ecbad796fc678bbdd269164fa9cf423a6dd2b88f Mon Sep 17 00:00:00 2001 From: Thomas Ziereis Date: Mon, 11 Nov 2024 14:14:26 +0100 Subject: [PATCH 2/2] renaming tests --- .../Vector/vector-rewrite-narrow-types.mlir | 212 +++++++++--------- 1 file changed, 107 insertions(+), 105 deletions(-) diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir index 75e46a79600d0..210025e30d7db 100644 --- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir +++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir @@ -193,36 +193,8 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> { return %1 : vector<8xi17> } -// CHECK-LABEL: func.func @aligned_extsi( -func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> { -// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> { -// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> -// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> -// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> -// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8> -// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> -// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8> -// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32> - %0 = arith.extsi %a : vector<8xi4> to vector<8xi32> - return %0 : vector<8xi32> -} - -// CHECK-LABEL: func.func @aligned_extsi_2d( -func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> { -// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> { -// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8> -// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8> -// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> -// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8> -// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> -// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8> -// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32> - %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32> - return %0 : vector<8x32xi32> -} - -// CHECK-LABEL: func.func @aligned_extsi_base_case( -func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> { +// CHECK-LABEL: func.func @aligned_extsi_i4_to_i8( +func.func @aligned_extsi_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> { // CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> { // CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> // CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> @@ -234,88 +206,61 @@ func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> { return %0 : vector<8xi8> } -// CHECK-LABEL: func.func @aligned_sitofp( -func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> { -// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> { +// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32( +func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> { // CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> // CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> // CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> // CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8> // CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> // CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8> -// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32> - %0 = arith.sitofp %a : vector<8xi4> to vector<8xf32> - return %0 : vector<8xf32> +// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32> + %0 = arith.extsi %a : vector<8xi4> to vector<8xi32> + return %0 : vector<8xi32> } -// CHECK-LABEL: func.func @aligned_sitofp_2d( -func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> { -// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> { +// CHECK-LABEL: func.func @aligned_extsi_2d( +func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> { +// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> { // CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8> // CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8> // CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> // CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8> // CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> // CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8> -// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32> - %0 = arith.sitofp %a : vector<8x32xi4> to vector<8x32xf32> - return %0 : vector<8x32xf32> -} - -// CHECK-LABEL: func.func @aligned_uitofp( -func.func @aligned_uitofp(%a: vector<8xi4>) -> vector<8xf32> { -// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> { -// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> -// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8> -// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> -// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8> -// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> -// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8> -// CHECK: %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32> - %0 = arith.uitofp %a : vector<8xi4> to vector<8xf32> - return %0 : vector<8xf32> +// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32> + %0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32> + return %0 : vector<8x32xi32> } -// CHECK-LABEL: func.func @aligned_uitofp_2d( -func.func @aligned_uitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> { -// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> { -// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8> -// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8> -// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8> -// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8> -// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> -// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8> -// CHECK: %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32> - %0 = arith.uitofp %a : vector<8x32xi4> to vector<8x32xf32> - return %0 : vector<8x32xf32> -} -// CHECK-LABEL: func.func @aligned_trunci( -func.func @aligned_trunci(%a: vector<8xi32>) -> vector<8xi4> { -// CHECK-SAME: %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> { +// CHECK-LABEL: func.func @aligned_trunci_i8_to_i4( +func.func @aligned_trunci_i8_to_i4(%a: vector<8xi8>) -> vector<8xi4> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> { // CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8> // CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> -// CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<8xi32> to vector<8xi8> -// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<8xi8> -> vector<4xi8> +// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[IN]] : vector<8xi8> -> vector<4xi8> // CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8> // CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8> // CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8> // CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4> - %0 = arith.trunci %a : vector<8xi32> to vector<8xi4> + %0 = arith.trunci %a : vector<8xi8> to vector<8xi4> return %0 : vector<8xi4> } -// CHECK-LABEL: func.func @aligned_trunci_base_case( -func.func @aligned_trunci_base_case(%a: vector<8xi8>) -> vector<8xi4> { -// CHECK-SAME: %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> { +// CHECK-LABEL: func.func @aligned_trunci_i32_to_i4( +func.func @aligned_trunci_i32_to_i4(%a: vector<8xi32>) -> vector<8xi4> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> { // CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8> // CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> -// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[IN]] : vector<8xi8> -> vector<4xi8> +// CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<8xi32> to vector<8xi8> +// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<8xi8> -> vector<4xi8> // CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8> // CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8> // CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8> // CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4> - %0 = arith.trunci %a : vector<8xi8> to vector<8xi4> + %0 = arith.trunci %a : vector<8xi32> to vector<8xi4> return %0 : vector<8xi4> } @@ -347,28 +292,21 @@ func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> { return %0 : vector<3x8x32xi4> } -// CHECK-LABEL: func.func @i4_transpose( -func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> { -// CHECK-SAME: %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> { -// CHECK: %[[EXT:.*]] = vector.interleave -// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8> -// CHECK: vector.deinterleave %[[TRANS]] : vector<16x8xi8> -> vector<16x4xi8> - %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4> - return %0 : vector<16x8xi4> -} - -// CHECK-LABEL: func.func @i7_transpose( -func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> { -// CHECK-SAME: %[[IN:.*]]: vector<8x16xi7>) -> vector<16x8xi7> { -// CHECK: %[[EXT:.*]] = arith.extsi %[[IN]] : vector<8x16xi7> to vector<8x16xi8> -// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8> -// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7> - %0 = vector.transpose %a, [1, 0] : vector<8x16xi7> to vector<16x8xi7> - return %0 : vector<16x8xi7> +// CHECK-LABEL: func.func @aligned_extui_i4_to_i8( +func.func @aligned_extui_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> { +// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> +// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> +// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8> +// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8> + %0 = arith.extui %a : vector<8xi4> to vector<8xi8> + return %0 : vector<8xi8> } -// CHECK-LABEL: func.func @aligned_extui( -func.func @aligned_extui(%a: vector<8xi4>) -> vector<8xi32> { +// CHECK-LABEL: func.func @aligned_extui_i4_to_i32( +func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> { // CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> { // CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> // CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8> @@ -395,19 +333,83 @@ func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> { return %0 : vector<8x32xi32> } -// CHECK-LABEL: func.func @aligned_extui_base_case( -func.func @aligned_extui_base_case(%a: vector<8xi4>) -> vector<8xi8> { -// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> { +// CHECK-LABEL: func.func @aligned_sitofp( +func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> { +// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> +// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8> +// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32> + %0 = arith.sitofp %a : vector<8xi4> to vector<8xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: func.func @aligned_sitofp_2d( +func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> { +// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> { +// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8> +// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> +// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8> +// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8> +// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32> + %0 = arith.sitofp %a : vector<8x32xi4> to vector<8x32xf32> + return %0 : vector<8x32xf32> +} + +// CHECK-LABEL: func.func @aligned_uitofp( +func.func @aligned_uitofp(%a: vector<8xi4>) -> vector<8xf32> { +// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> { // CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8> // CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8> // CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8> // CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8> // CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8> // CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8> - %0 = arith.extui %a : vector<8xi4> to vector<8xi8> - return %0 : vector<8xi8> +// CHECK: %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32> + %0 = arith.uitofp %a : vector<8xi4> to vector<8xf32> + return %0 : vector<8xf32> } +// CHECK-LABEL: func.func @aligned_uitofp_2d( +func.func @aligned_uitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> { +// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> { +// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8> +// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8> +// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8> +// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8> +// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8> +// CHECK: %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32> + %0 = arith.uitofp %a : vector<8x32xi4> to vector<8x32xf32> + return %0 : vector<8x32xf32> +} + +// CHECK-LABEL: func.func @i4_transpose( +func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> { +// CHECK-SAME: %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> { +// CHECK: %[[EXT:.*]] = vector.interleave +// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8> +// CHECK: vector.deinterleave %[[TRANS]] : vector<16x8xi8> -> vector<16x4xi8> + %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4> + return %0 : vector<16x8xi4> +} + +// CHECK-LABEL: func.func @i7_transpose( +func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> { +// CHECK-SAME: %[[IN:.*]]: vector<8x16xi7>) -> vector<16x8xi7> { +// CHECK: %[[EXT:.*]] = arith.extsi %[[IN]] : vector<8x16xi7> to vector<8x16xi8> +// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8> +// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7> + %0 = vector.transpose %a, [1, 0] : vector<8x16xi7> to vector<16x8xi7> + return %0 : vector<16x8xi7> +} + + module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { %f = transform.structured.match ops{["func.func"]} in %module_op