@@ -214,13 +214,13 @@ Value VectorContractRewriter::createMMLA(PatternRewriter &rewriter,
214
214
215
215
switch (mmlaOp) {
216
216
case MMLA::SignedInt:
217
- return rewriter.create< arm_sve::SmmlaOp>( loc, resTy, acc, lhs, rhs);
217
+ return arm_sve::SmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs);
218
218
case MMLA::UnsignedInt:
219
- return rewriter.create< arm_sve::UmmlaOp>( loc, resTy, acc, lhs, rhs);
219
+ return arm_sve::UmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs);
220
220
case MMLA::MixedInt:
221
- return rewriter.create< arm_sve::UsmmlaOp>( loc, resTy, acc, lhs, rhs);
221
+ return arm_sve::UsmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs);
222
222
case MMLA::Bfloat:
223
- return rewriter.create< arm_sve::BfmmlaOp>( loc, resTy, acc, lhs, rhs);
223
+ return arm_sve::BfmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs);
224
224
default:
225
225
llvm_unreachable("Uninitialized operation kind");
226
226
}
@@ -316,62 +316,63 @@ Value VectorContractRewriter::lower(vector::ContractionOp op,
316
316
for (int64_t i = 0; i < M; i += 2) {
317
317
// Extract two consecutive rows of the LHS tile.
318
318
auto r0 =
319
- rewriter.create< vector::ExtractOp>( loc, lhs, ArrayRef<int64_t>{i});
319
+ vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i});
320
320
auto r1 =
321
- rewriter.create< vector::ExtractOp>( loc, lhs, ArrayRef<int64_t>{i + 1});
321
+ vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i + 1});
322
322
// Concatenate to obtain a 2 x K x <input-type> flattened sub-tile.
323
323
SmallVector<int64_t> shuffleIdx(2 * K);
324
324
std::iota(shuffleIdx.begin(), shuffleIdx.end(), 0);
325
- auto t = rewriter.create< vector::ShuffleOp>( loc, r0, r1, shuffleIdx);
325
+ auto t = vector::ShuffleOp::create(rewriter, loc, r0, r1, shuffleIdx);
326
326
// Turn it into a scalable vector.
327
- auto s = rewriter.create< vector::ScalableInsertOp> (
328
- loc, t, rewriter.create< ub::PoisonOp>( loc, flatLhsType), 0);
327
+ auto s = vector::ScalableInsertOp::create (
328
+ rewriter, loc, t, ub::PoisonOp::create(rewriter, loc, flatLhsType), 0);
329
329
// Replicate the sub-tile VSCALE times to fill the entire vector.
330
- auto r = rewriter.create< arm_sve::DupQLaneOp>( loc, s, 0);
330
+ auto r = arm_sve::DupQLaneOp::create(rewriter, loc, s, 0);
331
331
lhsTile.push_back(r);
332
332
}
333
333
334
334
// "Flatten" the RHS tile from <[N]xK> to <[N*K]>.
335
- auto rhs = rewriter.create< vector::ShapeCastOp>( this->rhs.getLoc(),
336
- flatRhsTileType, this->rhs);
335
+ auto rhs = vector::ShapeCastOp::create(rewriter, this->rhs.getLoc(),
336
+ flatRhsTileType, this->rhs);
337
337
338
338
// Extract the RHS sub-tiles with logical shape <Kx[2]>.
339
339
SmallVector<Value> rhsTile;
340
340
for (int64_t j = 0; j < N; j += 2)
341
- rhsTile.push_back(rewriter.create< vector::ScalableExtractOp> (
342
- loc, flatRhsType, rhs, j * K));
341
+ rhsTile.push_back(vector::ScalableExtractOp::create (
342
+ rewriter, loc, flatRhsType, rhs, j * K));
343
343
344
344
// Extract and pack the ACC sub-tiles.
345
345
SmallVector<Value> accTile;
346
346
for (int64_t i = 0; i < M; i += 2) {
347
347
// Extract two consecutive rows of the accumulator tile.
348
- auto r0 = rewriter.create< vector::ExtractOp>( loc, op.getAcc(),
349
- ArrayRef<int64_t>{i});
350
- auto r1 = rewriter.create< vector::ExtractOp>( loc, op.getAcc(),
351
- ArrayRef<int64_t>{i + 1});
348
+ auto r0 = vector::ExtractOp::create(rewriter, loc, op.getAcc(),
349
+ ArrayRef<int64_t>{i});
350
+ auto r1 = vector::ExtractOp::create(rewriter, loc, op.getAcc(),
351
+ ArrayRef<int64_t>{i + 1});
352
352
Value accTileVec;
353
353
if (swapOperands) {
354
354
// We are performing the operation with swapped LHS and RHS we need to
355
355
// transpose each individual 2x2 tile of the accumulator and (later) the
356
356
// final result.
357
- accTileVec = rewriter.create< vector::InterleaveOp>( loc, r0, r1);
357
+ accTileVec = vector::InterleaveOp::create(rewriter, loc, r0, r1);
358
358
} else {
359
359
// Bitcast accumulator rows to double-width integer elements, so
360
360
// subsequent interleave/deinterleave work on pairs of elements.
361
- auto r0I64 = rewriter.create< vector::BitCastOp>( loc, accRow64Ty, r0);
362
- auto r1I64 = rewriter.create< vector::BitCastOp>( loc, accRow64Ty, r1);
361
+ auto r0I64 = vector::BitCastOp::create(rewriter, loc, accRow64Ty, r0);
362
+ auto r1I64 = vector::BitCastOp::create(rewriter, loc, accRow64Ty, r1);
363
363
364
364
// Interleave the rows, effectively flattening each 2x2 tile into 4
365
365
// consecutive elements.
366
- auto intrI64 = rewriter.create< vector::InterleaveOp>( loc, r0I64, r1I64);
366
+ auto intrI64 = vector::InterleaveOp::create(rewriter, loc, r0I64, r1I64);
367
367
368
368
// Bitcast back to original element type.
369
- accTileVec = rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intrI64);
369
+ accTileVec =
370
+ vector::BitCastOp::create(rewriter, loc, accRowX2Ty, intrI64);
370
371
}
371
372
// Extract ACC sub-tiles.
372
373
for (int64_t j = 0; j < N; j += 2)
373
- accTile.push_back(rewriter.create< vector::ScalableExtractOp> (
374
- loc, flatAccType, accTileVec, j * 2));
374
+ accTile.push_back(vector::ScalableExtractOp::create (
375
+ rewriter, loc, flatAccType, accTileVec, j * 2));
375
376
}
376
377
377
378
// Emit sub-tile matrix multiplications.
@@ -384,36 +385,36 @@ Value VectorContractRewriter::lower(vector::ContractionOp op,
384
385
}
385
386
386
387
// Unpack the OUT sub-tiles and insert into the result.
387
- Value result = rewriter.create< ub::PoisonOp>( loc, op.getResultType());
388
+ Value result = ub::PoisonOp::create(rewriter, loc, op.getResultType());
388
389
for (int64_t i = 0; i < M / 2; ++i) {
389
390
// Collect a number of sub-tiles in a row.
390
- Value row = rewriter.create< ub::PoisonOp>( loc, accRowX2Ty);
391
+ Value row = ub::PoisonOp::create(rewriter, loc, accRowX2Ty);
391
392
for (int64_t j = 0; j < N / 2; ++j)
392
- row = rewriter.create< vector::ScalableInsertOp> (
393
- loc, outTile[i * N / 2 + j], row, j * 4);
393
+ row = vector::ScalableInsertOp::create (
394
+ rewriter, loc, outTile[i * N / 2 + j], row, j * 4);
394
395
395
396
// Unpack the row to obtain two rows of the output. If we have the out
396
397
// sub-tiles transposed we obtain two consecutive output rows by
397
398
// separating even and odd elements, i.e. a simple deinterleave.
398
399
// Otherwise, the interleave is by pairs.
399
400
Value out0, out1;
400
401
if (swapOperands) {
401
- auto tmp = rewriter.create< vector::DeinterleaveOp>( loc, row);
402
+ auto tmp = vector::DeinterleaveOp::create(rewriter, loc, row);
402
403
out0 = tmp.getRes1();
403
404
out1 = tmp.getRes2();
404
405
} else {
405
406
// Deinterleave by pairs.
406
- auto row64 = rewriter.create< vector::BitCastOp>( loc, accRowX264Ty, row);
407
- auto deintr64 = rewriter.create< vector::DeinterleaveOp>( loc, row64);
407
+ auto row64 = vector::BitCastOp::create(rewriter, loc, accRowX264Ty, row);
408
+ auto deintr64 = vector::DeinterleaveOp::create(rewriter, loc, row64);
408
409
409
410
// Bitcast back into original element type and insert into the result.
410
- out0 =
411
- rewriter.create<vector::BitCastOp>(loc, accRowTy, deintr64.getRes1());
412
- out1 =
413
- rewriter.create<vector::BitCastOp>(loc, accRowTy, deintr64.getRes2());
411
+ out0 = vector::BitCastOp::create(rewriter, loc, accRowTy,
412
+ deintr64.getRes1());
413
+ out1 = vector::BitCastOp::create(rewriter, loc, accRowTy,
414
+ deintr64.getRes2());
414
415
}
415
- result = rewriter.create< vector::InsertOp>( loc, out0, result, i * 2);
416
- result = rewriter.create< vector::InsertOp>( loc, out1, result, i * 2 + 1);
416
+ result = vector::InsertOp::create(rewriter, loc, out0, result, i * 2);
417
+ result = vector::InsertOp::create(rewriter, loc, out1, result, i * 2 + 1);
417
418
}
418
419
419
420
return result;
0 commit comments