@@ -296,38 +296,49 @@ static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
296296 newLoad);
297297}
298298
299- // / Selects values from two sources based on a mask, and casts the result to a
300- // / new type.
301- static Value selectAndCast (OpBuilder &builder, Location loc,
302- VectorType castIntoType, Value mask, Value trueValue,
303- Value falseValue) {
304- Value maskedValue =
299+ // / Bitcasts `trueValue` and `falseValue` to `downcastType` (unless already of that type), selects between
300+ // / them based on `mask`, and bitcasts the result to `upcastType`. Both types must have the same total bit width.
301+ static Value downcastSelectAndUpcast (OpBuilder &builder, Location loc,
302+ VectorType downcastType,
303+ VectorType upcastType, Value mask,
304+ Value trueValue, Value falseValue) {
305+ assert (
306+ downcastType.getNumElements () * downcastType.getElementTypeBitWidth () ==
307+ upcastType.getNumElements () * upcastType.getElementTypeBitWidth () &&
308+ "expected input and output number of bits to match" );
309+ if (trueValue.getType () != downcastType)
310+ trueValue = builder.create <vector::BitCastOp>(loc, downcastType, trueValue);
311+ if (falseValue.getType () != downcastType)
312+ falseValue =
313+ builder.create <vector::BitCastOp>(loc, downcastType, falseValue);
314+ Value selectedValue =
305315 builder.create <arith::SelectOp>(loc, mask, trueValue, falseValue);
306- return builder.create <vector::BitCastOp>(loc, castIntoType, maskedValue);
316+ // Upcast the selected value to the new type.
317+ return builder.create <vector::BitCastOp>(loc, upcastType, selectedValue);
307318}
308319
309320// / Emits `memref.generic_atomic_rmw` op to store a subbyte-sized value to a
310- // / byte in memory , with a mask. The `valueToStore` is a vector of subbyte-sized
311- // / elements, with size of 8 bits, and the mask is used to select which elements
312- // / to store.
321+ // / byte in `linearizedMemref`, with a mask. The `valueToStore` is a vector of
322+ // / subbyte-sized elements with a total size of 8 bits, and the mask is used to
323+ // / select which elements to store.
313324// /
314325// / Inputs:
315326// / linearizedMemref = |2|2|2|2| : <4xi2> (<1xi8>)
316- // / linearizedIndex = 2
327+ // / storeIdx = 2
317328// / valueToStore = |3|3|3|3| : vector<4xi2>
318329// / mask = |0|0|1|1| : vector<4xi1>
319330// /
320331// / Result:
321332// / linearizedMemref = |2|2|3|3| : <4xi2> (<1xi8>)
322333static void atomicStore (OpBuilder &builder, Location loc,
323- MemRefValue linearizedMemref, Value linearizedIndex ,
334+ MemRefValue linearizedMemref, Value storeIdx ,
324335 VectorValue valueToStore, Value mask) {
325336 assert (valueToStore.getType ().getRank () == 1 && " expected 1-D vector" );
326337
327338 // Create an atomic load-modify-write region using
328339 // `memref.generic_atomic_rmw`.
329340 auto atomicOp = builder.create <memref::GenericAtomicRMWOp>(
330- loc, linearizedMemref, ValueRange{linearizedIndex });
341+ loc, linearizedMemref, ValueRange{storeIdx });
331342 Value origValue = atomicOp.getCurrentValue ();
332343
333344 OpBuilder::InsertionGuard guard (builder);
@@ -338,30 +349,30 @@ static void atomicStore(OpBuilder &builder, Location loc,
338349 auto oneElemVecType = VectorType::get ({1 }, origValue.getType ());
339350 Value origVecValue = builder.create <vector::FromElementsOp>(
340351 loc, oneElemVecType, ValueRange{origValue});
341- origVecValue = builder.create <vector::BitCastOp>(loc, valueToStore.getType (),
342- origVecValue);
343352
344353 // Construct the final masked value and yield it.
345- Value maskedValue = selectAndCast (builder, loc, oneElemVecType, mask,
346- valueToStore, origVecValue);
354+ Value maskedValue =
355+ downcastSelectAndUpcast (builder, loc, valueToStore.getType (),
356+ oneElemVecType, mask, valueToStore, origVecValue);
347357 auto scalarMaskedValue =
348358 builder.create <vector::ExtractOp>(loc, maskedValue, 0 );
349359 builder.create <memref::AtomicYieldOp>(loc, scalarMaskedValue);
350360}
351361
352- // / Extract `sliceNumElements` from source `vector` at `sliceOffset `,
353- // / and insert it into an empty vector at offset `byteOffset `.
362+ // / Extracts `sliceNumElements` elements from source `vector` starting at
363+ // / `extractOffset`, and inserts them into a zero-filled vector at `insertOffset`.
354364// / Inputs:
355- // / vector = |1|2|3|4 | : vector<4xi2>
356- // / sliceOffset = 1
365+ // / vec_in = |0|1|2|3| : vector<4xi2>
366+ // / extractOffset = 1
357367// / sliceNumElements = 2
358- // / byteOffset = 2
368+ // / insertOffset = 2
359369// / Output:
360- // / vector = |0|0|2|3 | : vector<4xi2>
370+ // / vec_out = |0|0|1|2| : vector<4xi2>
361371static Value extractSliceIntoByte (ConversionPatternRewriter &rewriter,
362372 Location loc, VectorValue vector,
363- int64_t sliceOffset, int64_t sliceNumElements,
364- int64_t byteOffset) {
373+ int64_t extractOffset,
374+ int64_t sliceNumElements,
375+ int64_t insertOffset) {
365376 assert (vector.getType ().getRank () == 1 && " expected 1-D vector" );
366377 auto vectorElementType = vector.getType ().getElementType ();
367378 assert (
@@ -374,9 +385,9 @@ static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
374385 loc, VectorType::get ({scale}, vectorElementType),
375386 rewriter.getZeroAttr (VectorType::get ({scale}, vectorElementType)));
376387 auto extracted = staticallyExtractSubvector (rewriter, loc, vector,
377- sliceOffset , sliceNumElements);
388+ extractOffset , sliceNumElements);
378389 return staticallyInsertSubvector (rewriter, loc, extracted, emptyByteVector,
379- byteOffset );
390+ insertOffset );
380391}
381392
382393namespace {
0 commit comments