Skip to content

Commit 70de874

Browse files
committed
Update according to comments
1 parent 358ca60 commit 70de874

File tree

2 files changed

+51
-38
lines changed

2 files changed

+51
-38
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -296,38 +296,49 @@ static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
296296
newLoad);
297297
}
298298

299-
/// Selects values from two sources based on a mask, and casts the result to a
300-
/// new type.
301-
static Value selectAndCast(OpBuilder &builder, Location loc,
302-
VectorType castIntoType, Value mask, Value trueValue,
303-
Value falseValue) {
304-
Value maskedValue =
299+
/// Downcast two values to `downcastType`, then select values
300+
/// based on `mask`, and cast the result to `upcastType`.
301+
static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
302+
VectorType downcastType,
303+
VectorType upcastType, Value mask,
304+
Value trueValue, Value falseValue) {
305+
assert(
306+
downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
307+
upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
308+
"expected upcastType size to be twice the size of downcastType");
309+
if (trueValue.getType() != downcastType)
310+
trueValue = builder.create<vector::BitCastOp>(loc, downcastType, trueValue);
311+
if (falseValue.getType() != downcastType)
312+
falseValue =
313+
builder.create<vector::BitCastOp>(loc, downcastType, falseValue);
314+
Value selectedType =
305315
builder.create<arith::SelectOp>(loc, mask, trueValue, falseValue);
306-
return builder.create<vector::BitCastOp>(loc, castIntoType, maskedValue);
316+
// Upcast the selected value to the new type.
317+
return builder.create<vector::BitCastOp>(loc, upcastType, selectedType);
307318
}
308319

309320
/// Emits `memref.generic_atomic_rmw` op to store a subbyte-sized value to a
310-
/// byte in memory, with a mask. The `valueToStore` is a vector of subbyte-sized
311-
/// elements, with size of 8 bits, and the mask is used to select which elements
312-
/// to store.
321+
/// byte in `linearizedMemref`, with a mask. The `valueToStore` is a vector of
322+
/// subbyte-sized elements, with a total size of 8 bits, and the mask is used to select
323+
/// which elements to store.
313324
///
314325
/// Inputs:
315326
/// linearizedMemref = |2|2|2|2| : <4xi2> (<1xi8>)
316-
/// linearizedIndex = 2
327+
/// storeIdx = 2
317328
/// valueToStore = |3|3|3|3| : vector<4xi2>
318329
/// mask = |0|0|1|1| : vector<4xi1>
319330
///
320331
/// Result:
321332
/// linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
322333
static void atomicStore(OpBuilder &builder, Location loc,
323-
MemRefValue linearizedMemref, Value linearizedIndex,
334+
MemRefValue linearizedMemref, Value storeIdx,
324335
VectorValue valueToStore, Value mask) {
325336
assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
326337

327338
// Create an atomic load-modify-write region using
328339
// `memref.generic_atomic_rmw`.
329340
auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
330-
loc, linearizedMemref, ValueRange{linearizedIndex});
341+
loc, linearizedMemref, ValueRange{storeIdx});
331342
Value origValue = atomicOp.getCurrentValue();
332343

333344
OpBuilder::InsertionGuard guard(builder);
@@ -338,30 +349,30 @@ static void atomicStore(OpBuilder &builder, Location loc,
338349
auto oneElemVecType = VectorType::get({1}, origValue.getType());
339350
Value origVecValue = builder.create<vector::FromElementsOp>(
340351
loc, oneElemVecType, ValueRange{origValue});
341-
origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
342-
origVecValue);
343352

344353
// Construct the final masked value and yield it.
345-
Value maskedValue = selectAndCast(builder, loc, oneElemVecType, mask,
346-
valueToStore, origVecValue);
354+
Value maskedValue =
355+
downcastSelectAndUpcast(builder, loc, valueToStore.getType(),
356+
oneElemVecType, mask, valueToStore, origVecValue);
347357
auto scalarMaskedValue =
348358
builder.create<vector::ExtractOp>(loc, maskedValue, 0);
349359
builder.create<memref::AtomicYieldOp>(loc, scalarMaskedValue);
350360
}
351361

352-
/// Extract `sliceNumElements` from source `vector` at `sliceOffset`,
353-
/// and insert it into an empty vector at offset `byteOffset`.
362+
/// Extract `sliceNumElements` from source `vector` at `extractOffset`,
363+
/// and insert it into an empty vector at `insertOffset`.
354364
/// Inputs:
355-
/// vector = |1|2|3|4| : vector<4xi2>
356-
/// sliceOffset = 1
365+
/// vec_in = |0|1|2|3| : vector<4xi2>
366+
/// extractOffset = 1
357367
/// sliceNumElements = 2
358-
/// byteOffset = 2
368+
/// insertOffset = 2
359369
/// Output:
360-
/// vector = |0|0|2|3| : vector<4xi2>
370+
/// vec_out = |0|0|1|2| : vector<4xi2>
361371
static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
362372
Location loc, VectorValue vector,
363-
int64_t sliceOffset, int64_t sliceNumElements,
364-
int64_t byteOffset) {
373+
int64_t extractOffset,
374+
int64_t sliceNumElements,
375+
int64_t insertOffset) {
365376
assert(vector.getType().getRank() == 1 && "expected 1-D vector");
366377
auto vectorElementType = vector.getType().getElementType();
367378
assert(
@@ -374,9 +385,9 @@ static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
374385
loc, VectorType::get({scale}, vectorElementType),
375386
rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
376387
auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
377-
sliceOffset, sliceNumElements);
388+
extractOffset, sliceNumElements);
378389
return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
379-
byteOffset);
390+
insertOffset);
380391
}
381392

382393
namespace {

mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -361,25 +361,27 @@ func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>)
361361
/// vector.store
362362
///----------------------------------------------------------------------------------------
363363

364-
func.func @vector_store_i2_const_index_two_atomic(%arg0: vector<3xi2>) {
365-
%0 = memref.alloc() : memref<3x3xi2>
364+
func.func @vector_store_i2_const_index_two_atomic_rmw(%arg0: vector<3xi2>) {
365+
%src = memref.alloc() : memref<3x3xi2>
366366
%c0 = arith.constant 0 : index
367367
%c2 = arith.constant 2 : index
368-
vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
368+
vector.store %arg0, %src[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
369369
return
370370
}
371371

372372
// In this example, emit 2 atomic RMWs.
373-
// Load from bit [12:18), byte [1:2] of total 3 bytes, both bytes needs rmw.
373+
//
374+
// Note, sizeof(%src) = 18 bits. This is modelled as %src_as_bytes:
375+
// <3xi8> (bits [0, 18) with the input values from %src, and [18, 24) are masked out)
374376

375-
// CHECK-LABEL: func @vector_store_i2_const_index_two_atomic(
377+
// CHECK-LABEL: func @vector_store_i2_const_index_two_atomic_rmw(
376378
// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
377379
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
378380
// CHECK: %[[C1:.+]] = arith.constant 1 : index
379381
// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]> : vector<4xi1>
380382
// CHECK: %[[CST_0:.+]] = arith.constant dense<0> : vector<4xi2>
381383

382-
// Part 1 atomic RMW sequence
384+
// Part 1 atomic RMW sequence (load bits [12, 16) from %src_as_bytes[1])
383385
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
384386
// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
385387
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST_0]]
@@ -393,7 +395,7 @@ func.func @vector_store_i2_const_index_two_atomic(%arg0: vector<3xi2>) {
393395
// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
394396
// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
395397

396-
// Part 2 atomic RMW sequence
398+
// Part 2 atomic RMW sequence (load bits [16, 18) from %src_as_bytes[2])
397399
// CHECK: %[[ADDR2:.+]] = arith.addi %[[C1]], %[[C1]] : index
398400
// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]]
399401
// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
@@ -411,7 +413,7 @@ func.func @vector_store_i2_const_index_two_atomic(%arg0: vector<3xi2>) {
411413

412414
// -----
413415

414-
func.func @vector_store_i2_atomic(%arg0: vector<7xi2>) {
416+
func.func @vector_store_i2_atomic_rmw(%arg0: vector<7xi2>) {
415417
%0 = memref.alloc() : memref<3x7xi2>
416418
%c0 = arith.constant 0 : index
417419
%c1 = arith.constant 1 : index
@@ -420,7 +422,7 @@ func.func @vector_store_i2_atomic(%arg0: vector<7xi2>) {
420422
}
421423

422424
// In this example, emit 2 atomic RMWs and 1 non-atomic store:
423-
// CHECK-LABEL: func @vector_store_i2_atomic(
425+
// CHECK-LABEL: func @vector_store_i2_atomic_rmw(
424426
// CHECK-SAME: %[[ARG0:.+]]: vector<7xi2>)
425427
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
426428
// CHECK: %[[C1:.+]] = arith.constant 1 : index
@@ -467,7 +469,7 @@ func.func @vector_store_i2_atomic(%arg0: vector<7xi2>) {
467469

468470
// -----
469471

470-
func.func @vector_store_i2_single_atomic(%arg0: vector<1xi2>) {
472+
func.func @vector_store_i2_const_index_one_atomic_rmw(%arg0: vector<1xi2>) {
471473
%0 = memref.alloc() : memref<4x1xi2>
472474
%c0 = arith.constant 0 : index
473475
%c1 = arith.constant 1 : index
@@ -476,7 +478,7 @@ func.func @vector_store_i2_single_atomic(%arg0: vector<1xi2>) {
476478
}
477479

478480
// In this example, only emit 1 atomic store
479-
// CHECK-LABEL: func @vector_store_i2_single_atomic(
481+
// CHECK-LABEL: func @vector_store_i2_const_index_one_atomic_rmw(
480482
// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
481483
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
482484
// CHECK: %[[C0:.+]] = arith.constant 0 : index

0 commit comments

Comments
 (0)