@@ -385,6 +385,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
385385 Location loc,
386386 const TypeConverter *typeConverter,
387387 bool isUnsigned, Value llvmInput,
388+ Value mlirInput,
388389 SmallVector<Value, 4 > &operands) {
389390 Type inputType = llvmInput.getType ();
390391 auto vectorType = dyn_cast<VectorType>(inputType);
@@ -398,23 +399,29 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
398399 return ;
399400 }
400401
402+ // We need to check the type of the input before conversion to properly test
403+ // for int8. This is because, in LLVM, fp8 type is converted to int8, so the
404+ // fp8/int8 information is lost during the conversion process.
405+ auto mlirInputType = cast<VectorType>(mlirInput.getType ());
406+ bool isInputInt8 = mlirInputType.getElementType ().isInteger (8 );
407+ if (isInputInt8) {
408+ // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
409+ bool localIsUnsigned = isUnsigned;
410+ if (elemType.isUnsignedInteger (8 )) {
411+ localIsUnsigned = true ;
412+ } else if (elemType.isSignedInteger (8 )) {
413+ localIsUnsigned = false ;
414+ }
415+ Value sign = createI1Constant (rewriter, loc, !localIsUnsigned);
416+ operands.push_back (sign);
417+ }
418+
401419 int64_t numBytes = vectorType.getNumElements ();
402420 Type i32 = rewriter.getI32Type ();
403421 VectorType vectorType32bits = VectorType::get (numBytes * 8 / 32 , i32 );
404422 auto llvmVectorType32bits = typeConverter->convertType (vectorType32bits);
405-
406423 Value result = rewriter.createOrFold <LLVM::BitcastOp>(
407424 loc, llvmVectorType32bits, llvmInput);
408-
409- // if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
410- bool localIsUnsigned = isUnsigned;
411- if (elemType.isUnsignedInteger (8 )) {
412- localIsUnsigned = true ;
413- } else if (elemType.isSignedInteger (8 )) {
414- localIsUnsigned = false ;
415- }
416- Value sign = createI1Constant (rewriter, loc, !localIsUnsigned);
417- operands.push_back (sign);
418425 operands.push_back (result);
419426}
420427
@@ -590,18 +597,20 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
590597 auto elemSourceType = sourceVectorType.getElementType ();
591598 auto elemDestType = destVectorType.getElementType ();
592599
593- if (elemSourceType.isF16 () && elemDestType.isF32 ()) {
600+ if (elemSourceType.isF16 () && elemDestType.isF32 ())
594601 return ROCDL::wmma_f32_16x16x16_f16::getOperationName ();
595- }
596- if (elemSourceType.isBF16 () && elemDestType.isF32 ()) {
602+ if (elemSourceType.isBF16 () && elemDestType.isF32 ())
597603 return ROCDL::wmma_f32_16x16x16_bf16::getOperationName ();
598- } else if (elemSourceType.isF16 () && elemDestType.isF16 ()) {
604+ if (elemSourceType.isF16 () && elemDestType.isF16 ())
599605 return ROCDL::wmma_f16_16x16x16_f16::getOperationName ();
600- } else if (elemSourceType.isBF16 () && elemDestType.isBF16 ()) {
606+ if (elemSourceType.isBF16 () && elemDestType.isBF16 ())
601607 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName ();
602- } else if (elemSourceType.isInteger (8 ) && elemDestType.isInteger (32 )) {
608+ if (elemSourceType.isInteger (8 ) && elemDestType.isInteger (32 ))
603609 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName ();
604- }
610+ if (elemSourceType.isFloat8E4M3FN () && elemDestType.isF32 ())
611+ return ROCDL::wmma_f32_16x16x16_fp8::getOperationName ();
612+ if (elemSourceType.isFloat8E5M2 () && elemDestType.isF32 ())
613+ return ROCDL::wmma_f32_16x16x16_bf8::getOperationName ();
605614 return std::nullopt ;
606615}
607616
@@ -662,8 +671,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
662671 Location loc = op.getLoc ();
663672 Type outType = typeConverter->convertType (op.getDestD ().getType ());
664673
665- if (chipset.majorVersion != 11 )
666- return op->emitOpError (" WMMA only supported on gfx11" );
674+ if (chipset.majorVersion != 11 && chipset. majorVersion != 12 )
675+ return op->emitOpError (" WMMA only supported on gfx11 and gfx12 " );
667676
668677 std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic (op, chipset);
669678
@@ -675,9 +684,9 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
675684
676685 SmallVector<Value, 4 > operands;
677686 wmmaPushInputOperand (rewriter, loc, typeConverter, op.getUnsignedA (),
678- adaptor.getSourceA (), operands);
687+ adaptor.getSourceA (), op. getSourceA (), operands);
679688 wmmaPushInputOperand (rewriter, loc, typeConverter, op.getUnsignedB (),
680- adaptor.getSourceB (), operands);
689+ adaptor.getSourceB (), op. getSourceB (), operands);
681690 wmmaPushOutputOperand (rewriter, loc, typeConverter, adaptor.getDestC (),
682691 op.getSubwordOffset (), op.getClamp (), operands);
683692
0 commit comments