@@ -214,13 +214,13 @@ Value VectorContractRewriter::createMMLA(PatternRewriter &rewriter,
214
214
215
215
switch (mmlaOp) {
216
216
case MMLA::SignedInt:
217
- return rewriter. create < arm_sve::SmmlaOp>( loc, resTy, acc, lhs, rhs);
217
+ return arm_sve::SmmlaOp::create (rewriter, loc, resTy, acc, lhs, rhs);
218
218
case MMLA::UnsignedInt:
219
- return rewriter. create < arm_sve::UmmlaOp>( loc, resTy, acc, lhs, rhs);
219
+ return arm_sve::UmmlaOp::create (rewriter, loc, resTy, acc, lhs, rhs);
220
220
case MMLA::MixedInt:
221
- return rewriter. create < arm_sve::UsmmlaOp>( loc, resTy, acc, lhs, rhs);
221
+ return arm_sve::UsmmlaOp::create (rewriter, loc, resTy, acc, lhs, rhs);
222
222
case MMLA::Bfloat:
223
- return rewriter. create < arm_sve::BfmmlaOp>( loc, resTy, acc, lhs, rhs);
223
+ return arm_sve::BfmmlaOp::create (rewriter, loc, resTy, acc, lhs, rhs);
224
224
default :
225
225
llvm_unreachable (" Uninitialized operation kind" );
226
226
}
@@ -316,62 +316,63 @@ Value VectorContractRewriter::lower(vector::ContractionOp op,
316
316
for (int64_t i = 0 ; i < M; i += 2 ) {
317
317
// Extract two consecutive rows of the LHS tile.
318
318
auto r0 =
319
- rewriter. create < vector::ExtractOp>( loc, lhs, ArrayRef<int64_t >{i});
319
+ vector::ExtractOp::create (rewriter, loc, lhs, ArrayRef<int64_t >{i});
320
320
auto r1 =
321
- rewriter. create < vector::ExtractOp>( loc, lhs, ArrayRef<int64_t >{i + 1 });
321
+ vector::ExtractOp::create (rewriter, loc, lhs, ArrayRef<int64_t >{i + 1 });
322
322
// Concatenate to obtain a 2 x K x <input-type> flattened sub-tile.
323
323
SmallVector<int64_t > shuffleIdx (2 * K);
324
324
std::iota (shuffleIdx.begin (), shuffleIdx.end (), 0 );
325
- auto t = rewriter. create < vector::ShuffleOp>( loc, r0, r1, shuffleIdx);
325
+ auto t = vector::ShuffleOp::create (rewriter, loc, r0, r1, shuffleIdx);
326
326
// Turn it into a scalable vector.
327
- auto s = rewriter. create < vector::ScalableInsertOp> (
328
- loc, t, rewriter. create < ub::PoisonOp>( loc, flatLhsType), 0 );
327
+ auto s = vector::ScalableInsertOp::create (
328
+ rewriter, loc, t, ub::PoisonOp::create (rewriter, loc, flatLhsType), 0 );
329
329
// Replicate the sub-tile VSCALE times to fill the entire vector.
330
- auto r = rewriter. create < arm_sve::DupQLaneOp>( loc, s, 0 );
330
+ auto r = arm_sve::DupQLaneOp::create (rewriter, loc, s, 0 );
331
331
lhsTile.push_back (r);
332
332
}
333
333
334
334
// "Flatten" the RHS tile from <[N]xK> to <[N*K]>.
335
- auto rhs = rewriter. create < vector::ShapeCastOp>( this ->rhs .getLoc (),
336
- flatRhsTileType, this ->rhs );
335
+ auto rhs = vector::ShapeCastOp::create (rewriter, this ->rhs .getLoc (),
336
+ flatRhsTileType, this ->rhs );
337
337
338
338
// Extract the RHS sub-tiles with logical shape <Kx[2]>.
339
339
SmallVector<Value> rhsTile;
340
340
for (int64_t j = 0 ; j < N; j += 2 )
341
- rhsTile.push_back (rewriter. create < vector::ScalableExtractOp> (
342
- loc, flatRhsType, rhs, j * K));
341
+ rhsTile.push_back (vector::ScalableExtractOp::create (
342
+ rewriter, loc, flatRhsType, rhs, j * K));
343
343
344
344
// Extract and pack the ACC sub-tiles.
345
345
SmallVector<Value> accTile;
346
346
for (int64_t i = 0 ; i < M; i += 2 ) {
347
347
// Extract two consecutive rows of the accumulator tile.
348
- auto r0 = rewriter. create < vector::ExtractOp>( loc, op.getAcc (),
349
- ArrayRef<int64_t >{i});
350
- auto r1 = rewriter. create < vector::ExtractOp>( loc, op.getAcc (),
351
- ArrayRef<int64_t >{i + 1 });
348
+ auto r0 = vector::ExtractOp::create (rewriter, loc, op.getAcc (),
349
+ ArrayRef<int64_t >{i});
350
+ auto r1 = vector::ExtractOp::create (rewriter, loc, op.getAcc (),
351
+ ArrayRef<int64_t >{i + 1 });
352
352
Value accTileVec;
353
353
if (swapOperands) {
354
354
// We are performing the operation with swapped LHS and RHS we need to
355
355
// transpose each individual 2x2 tile of the accumulator and (later) the
356
356
// final result.
357
- accTileVec = rewriter. create < vector::InterleaveOp>( loc, r0, r1);
357
+ accTileVec = vector::InterleaveOp::create (rewriter, loc, r0, r1);
358
358
} else {
359
359
// Bitcast accumulator rows to double-width integer elements, so
360
360
// subsequent interleave/deinterleave work on pairs of elements.
361
- auto r0I64 = rewriter. create < vector::BitCastOp>( loc, accRow64Ty, r0);
362
- auto r1I64 = rewriter. create < vector::BitCastOp>( loc, accRow64Ty, r1);
361
+ auto r0I64 = vector::BitCastOp::create (rewriter, loc, accRow64Ty, r0);
362
+ auto r1I64 = vector::BitCastOp::create (rewriter, loc, accRow64Ty, r1);
363
363
364
364
// Interleave the rows, effectively flattening each 2x2 tile into 4
365
365
// consecutive elements.
366
- auto intrI64 = rewriter. create < vector::InterleaveOp>( loc, r0I64, r1I64);
366
+ auto intrI64 = vector::InterleaveOp::create (rewriter, loc, r0I64, r1I64);
367
367
368
368
// Bitcast back to original element type.
369
- accTileVec = rewriter.create <vector::BitCastOp>(loc, accRowX2Ty, intrI64);
369
+ accTileVec =
370
+ vector::BitCastOp::create (rewriter, loc, accRowX2Ty, intrI64);
370
371
}
371
372
// Extract ACC sub-tiles.
372
373
for (int64_t j = 0 ; j < N; j += 2 )
373
- accTile.push_back (rewriter. create < vector::ScalableExtractOp> (
374
- loc, flatAccType, accTileVec, j * 2 ));
374
+ accTile.push_back (vector::ScalableExtractOp::create (
375
+ rewriter, loc, flatAccType, accTileVec, j * 2 ));
375
376
}
376
377
377
378
// Emit sub-tile matrix multiplications.
@@ -384,36 +385,36 @@ Value VectorContractRewriter::lower(vector::ContractionOp op,
384
385
}
385
386
386
387
// Unpack the OUT sub-tiles and insert into the result.
387
- Value result = rewriter. create < ub::PoisonOp>( loc, op.getResultType ());
388
+ Value result = ub::PoisonOp::create (rewriter, loc, op.getResultType ());
388
389
for (int64_t i = 0 ; i < M / 2 ; ++i) {
389
390
// Collect a number of sub-tiles in a row.
390
- Value row = rewriter. create < ub::PoisonOp>( loc, accRowX2Ty);
391
+ Value row = ub::PoisonOp::create (rewriter, loc, accRowX2Ty);
391
392
for (int64_t j = 0 ; j < N / 2 ; ++j)
392
- row = rewriter. create < vector::ScalableInsertOp> (
393
- loc, outTile[i * N / 2 + j], row, j * 4 );
393
+ row = vector::ScalableInsertOp::create (
394
+ rewriter, loc, outTile[i * N / 2 + j], row, j * 4 );
394
395
395
396
// Unpack the row to obtain two rows of the output. If we have the out
396
397
// sub-tiles transposed we obtain two consecutive output rows by
397
398
// separating even and odd elements, i.e. a simple deinterleave.
398
399
// Otherwise, the interleave is by pairs.
399
400
Value out0, out1;
400
401
if (swapOperands) {
401
- auto tmp = rewriter. create < vector::DeinterleaveOp>( loc, row);
402
+ auto tmp = vector::DeinterleaveOp::create (rewriter, loc, row);
402
403
out0 = tmp.getRes1 ();
403
404
out1 = tmp.getRes2 ();
404
405
} else {
405
406
// Deinterleave by pairs.
406
- auto row64 = rewriter. create < vector::BitCastOp>( loc, accRowX264Ty, row);
407
- auto deintr64 = rewriter. create < vector::DeinterleaveOp>( loc, row64);
407
+ auto row64 = vector::BitCastOp::create (rewriter, loc, accRowX264Ty, row);
408
+ auto deintr64 = vector::DeinterleaveOp::create (rewriter, loc, row64);
408
409
409
410
// Bitcast back into original element type and insert into the result.
410
- out0 =
411
- rewriter. create <vector::BitCastOp>(loc, accRowTy, deintr64.getRes1 ());
412
- out1 =
413
- rewriter. create <vector::BitCastOp>(loc, accRowTy, deintr64.getRes2 ());
411
+ out0 = vector::BitCastOp::create (rewriter, loc, accRowTy,
412
+ deintr64.getRes1 ());
413
+ out1 = vector::BitCastOp::create (rewriter, loc, accRowTy,
414
+ deintr64.getRes2 ());
414
415
}
415
- result = rewriter. create < vector::InsertOp>( loc, out0, result, i * 2 );
416
- result = rewriter. create < vector::InsertOp>( loc, out1, result, i * 2 + 1 );
416
+ result = vector::InsertOp::create (rewriter, loc, out0, result, i * 2 );
417
+ result = vector::InsertOp::create (rewriter, loc, out1, result, i * 2 + 1 );
417
418
}
418
419
419
420
return result;
0 commit comments