@@ -351,39 +351,23 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
351351
352352} // namespace
353353
354- // / If `input` is a vector of bytes, concatentate those bytes in little-endian
355- // / order to form a single integer of size 8 * [vector length]. This works
356- // / around a wart in the AMDGPU intrinsics where operations that logically take
357- // / vectors of bytes instead integers. Since we do not want to expose this
358- // / implementation detail to MLIR, we correct for it here.
354+ // / Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
355+ // / and LLVM AMDGPU intrinsics convention.
359356// /
360- // / In addition, convert vectors of LLVM bfloats to vectors of i16, since AMDGPU
361- // / MFMA intrinsics pre-date the bfloat type.
362- static Value mfmaConcatIfNeeded (ConversionPatternRewriter &rewriter,
363- Location loc, Value input) {
357+ // / Specifically:
358+ // / 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
359+ // / 2. If the element type is bfloat16, bitcast it to i16.
360+ static Value convertMFMAVectorOperand (ConversionPatternRewriter &rewriter,
361+ Location loc, Value input) {
364362 Type inputType = input.getType ();
365363 if (auto vectorType = dyn_cast<VectorType>(inputType)) {
366364 if (vectorType.getElementType ().isBF16 ())
367365 return rewriter.create <LLVM::BitcastOp>(
368366 loc, vectorType.clone (rewriter.getI16Type ()), input);
369-
370- if (!vectorType.getElementType ().isInteger (8 ))
371- return input;
372- int64_t numBytes = vectorType.getNumElements ();
373- Type destType = rewriter.getIntegerType (numBytes * 8 );
374- Value result = rewriter.create <LLVM::ConstantOp>(
375- loc, destType, rewriter.getIntegerAttr (destType, 0 ));
376- for (int64_t i = 0 ; i < numBytes; ++i) {
377- Value idxConst = createI32Constant (rewriter, loc, i);
378- Value element =
379- rewriter.create <LLVM::ExtractElementOp>(loc, input, idxConst);
380- Value extended = rewriter.create <LLVM::ZExtOp>(loc, destType, element);
381- Value shiftConst = rewriter.create <LLVM::ConstantOp>(
382- loc, destType, rewriter.getIntegerAttr (destType, i * 8 ));
383- Value shifted = rewriter.create <LLVM::ShlOp>(loc, extended, shiftConst);
384- result = rewriter.create <LLVM::OrOp>(loc, result, shifted);
367+ if (vectorType.getElementType ().isInteger (8 )) {
368+ return rewriter.create <LLVM::BitcastOp>(
369+ loc, rewriter.getIntegerType (vectorType.getNumElements () * 8 ), input);
385370 }
386- return result;
387371 }
388372 return input;
389373}
@@ -656,8 +640,8 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
656640 OperationState loweredOp (loc, *maybeIntrinsic);
657641 loweredOp.addTypes (intrinsicOutType);
658642 loweredOp.addOperands (
659- {mfmaConcatIfNeeded (rewriter, loc, adaptor.getSourceA ()),
660- mfmaConcatIfNeeded (rewriter, loc, adaptor.getSourceB ()),
643+ {convertMFMAVectorOperand (rewriter, loc, adaptor.getSourceA ()),
644+ convertMFMAVectorOperand (rewriter, loc, adaptor.getSourceB ()),
661645 adaptor.getDestC (), createI32Constant (rewriter, loc, op.getCbsz ()),
662646 createI32Constant (rewriter, loc, op.getAbid ()),
663647 createI32Constant (rewriter, loc, getBlgpField)});
0 commit comments