Skip to content

Commit 68d203b

Browse files
committed
[mlir][Vector] Introduce poison in LowerVectorBitCast/Broadcast/Transpose
This PR continues with the introduction of poison as initialization vector, in this particular case, in LowerVectorBitCast, LowerVectorBroadcast and LowerVectorTranspose.
1 parent a0d86b2 commit 68d203b

File tree

7 files changed

+53
-75
lines changed

7 files changed

+53
-75
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorBitCast.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14+
#include "mlir/Dialect/UB/IR/UBOps.h"
1415
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1516
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
1617
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
@@ -32,7 +33,7 @@ namespace {
3233
///
3334
/// Would be unrolled to:
3435
///
35-
/// %result = arith.constant dense<0> : vector<1x2x3x8xi32>
36+
/// %result = ub.poison : vector<1x2x3x8xi32>
3637
/// %0 = vector.extract %a[0, 0, 0] ─┐
3738
/// : vector<4xi64> from vector<1x2x3x4xi64> |
3839
/// %1 = vector.bitcast %0 | - Repeated 6x for
@@ -63,8 +64,7 @@ class UnrollBitCastOp final : public OpRewritePattern<vector::BitCastOp> {
6364
VectorType::get(shape, resultType.getElementType(), scalableDims);
6465

6566
Location loc = op.getLoc();
66-
Value result = rewriter.create<arith::ConstantOp>(
67-
loc, resultType, rewriter.getZeroAttr(resultType));
67+
Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
6868
for (auto position : *unrollIterator) {
6969
Value extract =
7070
rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);

mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,16 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14-
#include "mlir/Dialect/Affine/IR/AffineOps.h"
15-
#include "mlir/Dialect/Arith/IR/Arith.h"
16-
#include "mlir/Dialect/Arith/Utils/Utils.h"
17-
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1814
#include "mlir/Dialect/MemRef/IR/MemRef.h"
19-
#include "mlir/Dialect/SCF/IR/SCF.h"
20-
#include "mlir/Dialect/Tensor/IR/Tensor.h"
21-
#include "mlir/Dialect/Utils/IndexingUtils.h"
22-
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
15+
#include "mlir/Dialect/UB/IR/UBOps.h"
2316
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2417
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
2518
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
2619
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
27-
#include "mlir/IR/BuiltinAttributeInterfaces.h"
2820
#include "mlir/IR/BuiltinTypes.h"
29-
#include "mlir/IR/ImplicitLocOpBuilder.h"
3021
#include "mlir/IR/Location.h"
31-
#include "mlir/IR/Matchers.h"
3222
#include "mlir/IR/PatternMatch.h"
3323
#include "mlir/IR/TypeUtilities.h"
34-
#include "mlir/Interfaces/VectorInterfaces.h"
3524

3625
#define DEBUG_TYPE "vector-broadcast-lowering"
3726

@@ -86,8 +75,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
8675
VectorType resType = VectorType::Builder(dstType).dropDim(0);
8776
Value bcst =
8877
rewriter.create<vector::BroadcastOp>(loc, resType, op.getSource());
89-
Value result = rewriter.create<arith::ConstantOp>(
90-
loc, dstType, rewriter.getZeroAttr(dstType));
78+
Value result = rewriter.create<ub::PoisonOp>(loc, dstType);
9179
for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
9280
result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
9381
rewriter.replaceOp(op, result);
@@ -127,8 +115,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
127115
VectorType resType =
128116
VectorType::get(dstType.getShape().drop_front(), eltType,
129117
dstType.getScalableDims().drop_front());
130-
Value result = rewriter.create<arith::ConstantOp>(
131-
loc, dstType, rewriter.getZeroAttr(dstType));
118+
Value result = rewriter.create<ub::PoisonOp>(loc, dstType);
132119
if (m == 0) {
133120
// Stetch at start.
134121
Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);

mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,26 +11,19 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14-
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1514
#include "mlir/Dialect/Arith/IR/Arith.h"
16-
#include "mlir/Dialect/Arith/Utils/Utils.h"
17-
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1815
#include "mlir/Dialect/MemRef/IR/MemRef.h"
19-
#include "mlir/Dialect/SCF/IR/SCF.h"
20-
#include "mlir/Dialect/Tensor/IR/Tensor.h"
16+
#include "mlir/Dialect/UB/IR/UBOps.h"
2117
#include "mlir/Dialect/Utils/IndexingUtils.h"
2218
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
2319
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2420
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
2521
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
26-
#include "mlir/IR/BuiltinAttributeInterfaces.h"
2722
#include "mlir/IR/BuiltinTypes.h"
2823
#include "mlir/IR/ImplicitLocOpBuilder.h"
2924
#include "mlir/IR/Location.h"
30-
#include "mlir/IR/Matchers.h"
3125
#include "mlir/IR/PatternMatch.h"
3226
#include "mlir/IR/TypeUtilities.h"
33-
#include "mlir/Interfaces/VectorInterfaces.h"
3427

3528
#define DEBUG_TYPE "lower-vector-transpose"
3629

@@ -291,8 +284,7 @@ static Value transposeToShuffle16x16(OpBuilder &builder, Value source, int m,
291284

292285
auto reshInputType = VectorType::get(
293286
{m, n}, cast<VectorType>(source.getType()).getElementType());
294-
Value res =
295-
b.create<arith::ConstantOp>(reshInputType, b.getZeroAttr(reshInputType));
287+
Value res = b.create<ub::PoisonOp>(reshInputType);
296288
for (int64_t i = 0; i < m; ++i)
297289
res = b.create<vector::InsertOp>(vs[i], res, i);
298290
return res;
@@ -368,8 +360,7 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
368360
// of the leftmost transposed dimensions. We traverse every transpose
369361
// element using a linearized index that we delinearize to generate the
370362
// appropriate indices for the extract/insert operations.
371-
Value result = rewriter.create<arith::ConstantOp>(
372-
loc, resType, rewriter.getZeroAttr(resType));
363+
Value result = rewriter.create<ub::PoisonOp>(loc, resType);
373364
int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
374365

375366
for (int64_t linearIdx = 0; linearIdx < numTransposedElements;

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ func.func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> {
286286
// CHECK-LABEL: @broadcast_vec2d_from_vec0d(
287287
// CHECK-SAME: %[[A:.*]]: vector<f32>)
288288
// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
289-
// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
289+
// CHECK: %[[T1:.*]] = ub.poison : vector<3x2xf32>
290290
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
291291
// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : i64
292292
// CHECK: %[[T5:.*]] = llvm.extractelement %[[T0]][%[[T4]] : i64] : vector<1xf32>
@@ -306,7 +306,7 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> {
306306
}
307307
// CHECK-LABEL: @broadcast_vec2d_from_vec1d(
308308
// CHECK-SAME: %[[A:.*]]: vector<2xf32>)
309-
// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
309+
// CHECK: %[[T0:.*]] = ub.poison : vector<3x2xf32>
310310
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
311311
// CHECK: %[[T2:.*]] = llvm.insertvalue %[[A]], %[[T1]][0] : !llvm.array<3 x vector<2xf32>>
312312
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][1] : !llvm.array<3 x vector<2xf32>>
@@ -322,7 +322,7 @@ func.func @broadcast_vec2d_from_vec1d_scalable(%arg0: vector<[2]xf32>) -> vector
322322
}
323323
// CHECK-LABEL: @broadcast_vec2d_from_vec1d_scalable(
324324
// CHECK-SAME: %[[A:.*]]: vector<[2]xf32>)
325-
// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x[2]xf32>
325+
// CHECK: %[[T0:.*]] = ub.poison : vector<3x[2]xf32>
326326
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
327327
// CHECK: %[[T2:.*]] = llvm.insertvalue %[[A]], %[[T1]][0] : !llvm.array<3 x vector<[2]xf32>>
328328
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][1] : !llvm.array<3 x vector<[2]xf32>>
@@ -339,7 +339,7 @@ func.func @broadcast_vec2d_from_index_vec1d(%arg0: vector<2xindex>) -> vector<3x
339339
// CHECK-LABEL: @broadcast_vec2d_from_index_vec1d(
340340
// CHECK-SAME: %[[A:.*]]: vector<2xindex>)
341341
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2xindex> to vector<2xi64>
342-
// CHECK: %[[T0:.*]] = arith.constant dense<0> : vector<3x2xindex>
342+
// CHECK: %[[T0:.*]] = ub.poison : vector<3x2xindex>
343343
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xindex> to !llvm.array<3 x vector<2xi64>>
344344
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<3 x vector<2xi64>>
345345

@@ -355,7 +355,7 @@ func.func @broadcast_vec2d_from_index_vec1d_scalable(%arg0: vector<[2]xindex>) -
355355
// CHECK-LABEL: @broadcast_vec2d_from_index_vec1d_scalable(
356356
// CHECK-SAME: %[[A:.*]]: vector<[2]xindex>)
357357
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<[2]xindex> to vector<[2]xi64>
358-
// CHECK: %[[T0:.*]] = arith.constant dense<0> : vector<3x[2]xindex>
358+
// CHECK: %[[T0:.*]] = ub.poison : vector<3x[2]xindex>
359359
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x[2]xindex> to !llvm.array<3 x vector<[2]xi64>>
360360
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<3 x vector<[2]xi64>>
361361

@@ -370,9 +370,9 @@ func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32>
370370
}
371371
// CHECK-LABEL: @broadcast_vec3d_from_vec1d(
372372
// CHECK-SAME: %[[A:.*]]: vector<2xf32>)
373-
// CHECK-DAG: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
373+
// CHECK-DAG: %[[T0:.*]] = ub.poison : vector<3x2xf32>
374374
// CHECK-DAG: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
375-
// CHECK-DAG: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
375+
// CHECK-DAG: %[[T1:.*]] = ub.poison : vector<4x3x2xf32>
376376
// CHECK-DAG: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
377377

378378
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
@@ -395,9 +395,9 @@ func.func @broadcast_vec3d_from_vec1d_scalable(%arg0: vector<[2]xf32>) -> vector
395395
}
396396
// CHECK-LABEL: @broadcast_vec3d_from_vec1d_scalable(
397397
// CHECK-SAME: %[[A:.*]]: vector<[2]xf32>)
398-
// CHECK-DAG: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x[2]xf32>
398+
// CHECK-DAG: %[[T0:.*]] = ub.poison : vector<3x[2]xf32>
399399
// CHECK-DAG: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
400-
// CHECK-DAG: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x[2]xf32>
400+
// CHECK-DAG: %[[T1:.*]] = ub.poison : vector<4x3x[2]xf32>
401401
// CHECK-DAG: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x[2]xf32> to !llvm.array<4 x array<3 x vector<[2]xf32>>>
402402

403403
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][0] : !llvm.array<3 x vector<[2]xf32>>
@@ -421,7 +421,7 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3
421421
// CHECK-LABEL: @broadcast_vec3d_from_vec2d(
422422
// CHECK-SAME: %[[A:.*]]: vector<3x2xf32>)
423423
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
424-
// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
424+
// CHECK: %[[T0:.*]] = ub.poison : vector<4x3x2xf32>
425425
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
426426
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<4 x array<3 x vector<2xf32>>>
427427
// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T1]], %[[T3]][1] : !llvm.array<4 x array<3 x vector<2xf32>>>
@@ -439,7 +439,7 @@ func.func @broadcast_vec3d_from_vec2d_scalable(%arg0: vector<3x[2]xf32>) -> vect
439439
// CHECK-LABEL: @broadcast_vec3d_from_vec2d_scalable(
440440
// CHECK-SAME: %[[A:.*]]: vector<3x[2]xf32>)
441441
// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
442-
// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x[2]xf32>
442+
// CHECK: %[[T0:.*]] = ub.poison : vector<4x3x[2]xf32>
443443
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<4x3x[2]xf32> to !llvm.array<4 x array<3 x vector<[2]xf32>>>
444444
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
445445
// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T1]], %[[T3]][1] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
@@ -486,7 +486,7 @@ func.func @broadcast_stretch_at_start(%arg0: vector<1x4xf32>) -> vector<3x4xf32>
486486
// CHECK-LABEL: @broadcast_stretch_at_start(
487487
// CHECK-SAME: %[[A:.*]]: vector<1x4xf32>)
488488
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<1x4xf32> to !llvm.array<1 x vector<4xf32>>
489-
// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<3x4xf32>
489+
// CHECK: %[[T1:.*]] = ub.poison : vector<3x4xf32>
490490
// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x4xf32> to !llvm.array<3 x vector<4xf32>>
491491
// CHECK: %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm.array<1 x vector<4xf32>>
492492
// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T3]], %[[T4]][0] : !llvm.array<3 x vector<4xf32>>
@@ -504,7 +504,7 @@ func.func @broadcast_stretch_at_start_scalable(%arg0: vector<1x[4]xf32>) -> vect
504504
// CHECK-LABEL: @broadcast_stretch_at_start_scalable(
505505
// CHECK-SAME: %[[A:.*]]: vector<1x[4]xf32>)
506506
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<1x[4]xf32> to !llvm.array<1 x vector<[4]xf32>>
507-
// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<3x[4]xf32>
507+
// CHECK: %[[T1:.*]] = ub.poison : vector<3x[4]xf32>
508508
// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x[4]xf32> to !llvm.array<3 x vector<[4]xf32>>
509509
// CHECK: %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm.array<1 x vector<[4]xf32>>
510510
// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T3]], %[[T4]][0] : !llvm.array<3 x vector<[4]xf32>>
@@ -522,7 +522,7 @@ func.func @broadcast_stretch_at_end(%arg0: vector<4x1xf32>) -> vector<4x3xf32> {
522522
// CHECK-LABEL: @broadcast_stretch_at_end(
523523
// CHECK-SAME: %[[A:.*]]: vector<4x1xf32>)
524524
// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1xf32> to !llvm.array<4 x vector<1xf32>>
525-
// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3xf32>
525+
// CHECK: %[[T1:.*]] = ub.poison : vector<4x3xf32>
526526
// CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3xf32> to !llvm.array<4 x vector<3xf32>>
527527
// CHECK: %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm.array<4 x vector<1xf32>>
528528
// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64
@@ -570,9 +570,9 @@ func.func @broadcast_stretch_in_middle(%arg0: vector<4x1x2xf32>) -> vector<4x3x2
570570
// CHECK-LABEL: @broadcast_stretch_in_middle(
571571
// CHECK-SAME: %[[A:.*]]: vector<4x1x2xf32>) -> vector<4x3x2xf32> {
572572
// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1x2xf32> to !llvm.array<4 x array<1 x vector<2xf32>>>
573-
// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
573+
// CHECK: %[[T1:.*]] = ub.poison : vector<4x3x2xf32>
574574
// CHECK: %[[T9:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
575-
// CHECK: %[[T2:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
575+
// CHECK: %[[T2:.*]] = ub.poison : vector<3x2xf32>
576576
// CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
577577
// CHECK: %[[T4:.*]] = llvm.extractvalue %[[T3]][0, 0] : !llvm.array<4 x array<1 x vector<2xf32>>>
578578
// CHECK: %[[T6:.*]] = llvm.insertvalue %[[T4]], %[[T5]][0] : !llvm.array<3 x vector<2xf32>>
@@ -606,9 +606,9 @@ func.func @broadcast_stretch_in_middle_scalable_v1(%arg0: vector<4x1x[2]xf32>) -
606606
// CHECK-LABEL: @broadcast_stretch_in_middle_scalable_v1(
607607
// CHECK-SAME: %[[A:.*]]: vector<4x1x[2]xf32>) -> vector<4x3x[2]xf32> {
608608
// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1x[2]xf32> to !llvm.array<4 x array<1 x vector<[2]xf32>>>
609-
// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x[2]xf32>
609+
// CHECK: %[[T1:.*]] = ub.poison : vector<4x3x[2]xf32>
610610
// CHECK: %[[T9:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x[2]xf32> to !llvm.array<4 x array<3 x vector<[2]xf32>>>
611-
// CHECK: %[[T2:.*]] = arith.constant dense<0.000000e+00> : vector<3x[2]xf32>
611+
// CHECK: %[[T2:.*]] = ub.poison : vector<3x[2]xf32>
612612
// CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
613613
// CHECK: %[[T4:.*]] = llvm.extractvalue %[[T3]][0, 0] : !llvm.array<4 x array<1 x vector<[2]xf32>>>
614614
// CHECK: %[[T6:.*]] = llvm.insertvalue %[[T4]], %[[T5]][0] : !llvm.array<3 x vector<[2]xf32>>

mlir/test/Dialect/Vector/vector-bitcast-lowering-transforms.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> {
2424
}
2525
// CHECK-LABEL: func.func @vector_bitcast_2d
2626
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
27-
// CHECK: %[[INIT:.+]] = arith.constant {{.+}} : vector<2x2xi64>
27+
// CHECK: %[[INIT:.+]] = ub.poison : vector<2x2xi64>
2828
// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0] : vector<4xi32> from vector<2x4xi32>
2929
// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<4xi32> to vector<2xi64>
3030
// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0]
@@ -39,7 +39,7 @@ func.func @vector_bitcast_4d_with_scalable_dim(%arg0: vector<1x2x[3]x4xi64>) ->
3939
}
4040
// CHECK-LABEL: func.func @vector_bitcast_4d_with_scalable_dim
4141
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
42-
// CHECK: %[[INIT:.+]] = arith.constant dense<0> : vector<1x2x[3]x8xi32>
42+
// CHECK: %[[INIT:.+]] = ub.poison : vector<1x2x[3]x8xi32>
4343
// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0, 0] : vector<[3]x4xi64> from vector<1x2x[3]x4xi64>
4444
// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<[3]x4xi64> to vector<[3]x8xi32>
4545
// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0, 0] : vector<[3]x8xi32> into vector<1x2x[3]x8xi32>
@@ -54,7 +54,7 @@ func.func @vector_bitcast_2d_trailing_scalable_dim(%arg0: vector<2x[2]xi64>) ->
5454
}
5555
// CHECK-LABEL: func.func @vector_bitcast_2d_trailing_scalable_dim
5656
// CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]
57-
// CHECK: %[[INIT:.+]] = arith.constant dense<0> : vector<2x[4]xi32>
57+
// CHECK: %[[INIT:.+]] = ub.poison : vector<2x[4]xi32>
5858
// CHECK: %[[V1:.+]] = vector.extract %[[IN]][0] : vector<[2]xi64> from vector<2x[2]xi64>
5959
// CHECK: %[[B1:.+]] = vector.bitcast %[[V1]] : vector<[2]xi64> to vector<[4]xi32>
6060
// CHECK: %[[R1:.+]] = vector.insert %[[B1]], %[[INIT]] [0] : vector<[4]xi32> into vector<2x[4]xi32>

0 commit comments

Comments
 (0)