diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp index 6ec7544767c52..25cdf72a658a8 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -148,6 +148,7 @@ struct ConvertBuiltin { bool IsSaturated; bool IsRounded; bool IsBfloat16; + bool IsTF32; FPRoundingMode::FPRoundingMode RoundingMode; }; @@ -230,6 +231,7 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall, // - "__spirv_SubgroupImageMediaBlockReadINTEL" // - "__spirv_SubgroupImageMediaBlockWriteINTEL" // - "__spirv_Convert" + // - "__spirv_Round" // - "__spirv_UConvert" // - "__spirv_SConvert" // - "__spirv_FConvert" @@ -242,7 +244,7 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall, "SDotKHR|SUDotKHR|SDotAccSatKHR|UDotAccSatKHR|SUDotAccSatKHR|" "ReadClockKHR|SubgroupBlockReadINTEL|SubgroupImageBlockReadINTEL|" "SubgroupImageMediaBlockReadINTEL|SubgroupImageMediaBlockWriteINTEL|" - "Convert|" + "Convert|Round|" "UConvert|SConvert|FConvert|SatConvert)[^_]*)(_R[^_]*_?(\\w+)?.*)?"); std::smatch Match; if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 1) { @@ -697,7 +699,8 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call, MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) { if (Call->isSpirvOp()) - return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0)); + return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, + Register(0)); Register ScopeRegister = buildConstantIntReg32(SPIRV::Scope::Device, MIRBuilder, GR); @@ -2677,8 +2680,20 @@ static bool generateConvertInst(const StringRef DemangledCall, } } else if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeFloat)) { - // Float -> Float - Opcode = SPIRV::OpFConvert; + if (Builtin->IsTF32) { + const auto *ST = static_cast<const SPIRVSubtarget *>( + &MIRBuilder.getMF().getSubtarget()); + if (!ST->canUseExtension( + SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) + NeedExtMsg = 
"SPV_INTEL_tensor_float32_conversion"; + IsRightComponentsNumber = + GR->getScalarOrVectorComponentCount(Call->Arguments[0]) == + GR->getScalarOrVectorComponentCount(Call->ReturnRegister); + Opcode = SPIRV::OpRoundFToTF32INTEL; + } else { + // Float -> Float + Opcode = SPIRV::OpFConvert; + } } } diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td index ea78dcd135267..d08560bb6565a 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td @@ -1461,6 +1461,8 @@ class ConvertBuiltin { bit IsRounded = !not(!eq(!find(name, "_rt"), -1)); bit IsBfloat16 = !or(!not(!eq(!find(name, "BF16"), -1)), !not(!eq(!find(name, "bfloat16"), -1))); + bit IsTF32 = !or(!not(!eq(!find(name, "TF32"), -1)), + !not(!eq(!find(name, "tensor_float32"), -1))); FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE, !not(!eq(!find(name, "_rtz"), -1)) : RTZ, !not(!eq(!find(name, "_rtp"), -1)) : RTP, @@ -1472,7 +1474,7 @@ class ConvertBuiltin { def ConvertBuiltins : GenericTable { let FilterClass = "ConvertBuiltin"; let Fields = ["Name", "Set", "IsDestinationSigned", "IsSaturated", - "IsRounded", "IsBfloat16", "RoundingMode"]; + "IsRounded", "IsBfloat16", "IsTF32", "RoundingMode"]; string TypeOf_Set = "InstructionSet"; string TypeOf_RoundingMode = "FPRoundingMode"; } @@ -1556,6 +1558,25 @@ foreach conv = ["FToBF16INTEL", "BF16ToFINTEL"] in { def : ConvertBuiltin; } +// cl_intel_tensor_float32_conversions / SPV_INTEL_tensor_float32_conversion +// Multiclass used to define at the same time both a demangled builtin record +// and a corresponding convert builtin record. +multiclass DemangledTF32RoundBuiltin { + // Create records for scalar and vector conversions. 
+ foreach i = ["", "2", "3", "4", "8", "16"] in { + def : DemangledBuiltin; + def : ConvertBuiltin; + } +} + +defm : DemangledTF32RoundBuiltin<"tensor_float32", "_as_float">; +defm : DemangledTF32RoundBuiltin<"as_tensor_float32", "_float">; + +foreach conv = ["FToTF32INTEL"] in { + def : DemangledBuiltin; + def : ConvertBuiltin; +} + //===----------------------------------------------------------------------===// // Class defining a vector data load/store builtin record used for lowering // into OpExtInst instruction. diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp index 2726203d253ad..d9265f498973e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp @@ -102,7 +102,9 @@ static const std::map> SPIRV::Extension::Extension::SPV_INTEL_2d_block_io}, {"SPV_INTEL_int4", SPIRV::Extension::Extension::SPV_INTEL_int4}, {"SPV_KHR_float_controls2", - SPIRV::Extension::Extension::SPV_KHR_float_controls2}}; + SPIRV::Extension::Extension::SPV_KHR_float_controls2}, + {"SPV_INTEL_tensor_float32_conversion", + SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}}; bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName, StringRef ArgValue, diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td index 049ba0275f223..f0b938d681dba 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td @@ -445,6 +445,9 @@ def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938 def OpConvertFToBF16INTEL : UnOp<"OpConvertFToBF16INTEL", 6116>; def OpConvertBF16ToFINTEL : UnOp<"OpConvertBF16ToFINTEL", 6117>; +// SPV_INTEL_tensor_float32_conversion +def OpRoundFToTF32INTEL : UnOp<"OpRoundFToTF32INTEL", 6426>; + // 3.42.12 Composite Instructions def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx), diff --git 
a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index ad976e5288927..0cd9d7882a52a 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1564,6 +1564,13 @@ void addInstrRequirements(const MachineInstr &MI, Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL); } break; + case SPIRV::OpRoundFToTF32INTEL: + if (ST.canUseExtension( + SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) { + Reqs.addExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion); + Reqs.addCapability(SPIRV::Capability::TensorFloat32RoundingINTEL); + } + break; case SPIRV::OpVariableLengthArrayINTEL: case SPIRV::OpSaveMemoryINTEL: case SPIRV::OpRestoreMemoryINTEL: diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td index 548e9b717c161..614e83ae9b286 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td +++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td @@ -320,6 +320,7 @@ defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>; defm SPV_INTEL_2d_block_io : ExtensionOperand<122>; defm SPV_INTEL_int4 : ExtensionOperand<123>; defm SPV_KHR_float_controls2 : ExtensionOperand<124>; +defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125>; //===----------------------------------------------------------------------===// // Multiclass used to define Capabilities enum values and at the same time @@ -529,6 +530,7 @@ defm Subgroup2DBlockTransformINTEL : CapabilityOperand<6229, 0, 0, [SPV_INTEL_2d defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>; defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>; defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>; +defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, 
[SPV_INTEL_tensor_float32_conversion], []>; //===----------------------------------------------------------------------===// // Multiclass used to define SourceLanguage enum values and at the same time diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv-negative1.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv-negative1.ll new file mode 100644 index 0000000000000..fa708ab022a85 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv-negative1.ll @@ -0,0 +1,12 @@ +; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR +; CHECK-ERROR: result and argument must have the same number of components + +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64" +target triple = "spir64-unknown-unknown" + +define spir_func void @test(<8 x float> %in) { + %res = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in) + ret void +} + +declare spir_func float @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>) diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv-negative2.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv-negative2.ll new file mode 100644 index 0000000000000..630b2fdd7696c --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv-negative2.ll @@ -0,0 +1,12 @@ +; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR +; CHECK-ERROR: result and argument must have the same number of components + +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64" +target triple = 
"spir64-unknown-unknown" + +define spir_func void @test(<8 x float> %in) { + %res = tail call spir_func <4 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in) + ret void +} + +declare spir_func <4 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>) diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv.ll new file mode 100644 index 0000000000000..dcad78d17bff7 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv.ll @@ -0,0 +1,62 @@ +; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o - -filetype=obj | spirv-val %} + +; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR +; CHECK-ERROR: the builtin requires the following SPIR-V extension: SPV_INTEL_tensor_float32_conversion + +; CHECK: OpCapability TensorFloat32RoundingINTEL +; CHECK: OpExtension "SPV_INTEL_tensor_float32_conversion" + +; CHECK-DAG: %[[VoidTy:.*]] = OpTypeVoid +; CHECK-DAG: %[[FP32Ty:.*]] = OpTypeFloat 32 +; CHECK-DAG: %[[VecFloat2:.*]] = OpTypeVector %[[FP32Ty]] 2 +; CHECK-DAG: %[[VecFloat3:.*]] = OpTypeVector %[[FP32Ty]] 3 +; CHECK-DAG: %[[VecFloat4:.*]] = OpTypeVector %[[FP32Ty]] 4 +; CHECK-DAG: %[[VecFloat8:.*]] = OpTypeVector %[[FP32Ty]] 8 +; CHECK-DAG: %[[VecFloat16:.*]] = OpTypeVector %[[FP32Ty]] 16 +; CHECK-DAG: %[[FloatConstId:.*]] = OpConstant %[[FP32Ty]] 1.5 + +; CHECK: OpFunction %[[VoidTy]] +; CHECK: %[[FP32ValId:.*]] = OpFunctionParameter %[[FP32Ty]] +; CHECK: %[[FP32v8ValId:.*]] = OpFunctionParameter %[[VecFloat8]] +; CHECK: OpRoundFToTF32INTEL %[[FP32Ty]] %[[FP32ValId]] +; CHECK: OpRoundFToTF32INTEL %[[VecFloat8]] %[[FP32v8ValId]] +; CHECK: 
OpRoundFToTF32INTEL %[[FP32Ty]] %[[FloatConstId]] + +; CHECK: OpRoundFToTF32INTEL %[[FP32Ty]] +; CHECK: OpRoundFToTF32INTEL %[[VecFloat2]] +; CHECK: OpRoundFToTF32INTEL %[[VecFloat3]] +; CHECK: OpRoundFToTF32INTEL %[[VecFloat4]] +; CHECK: OpRoundFToTF32INTEL %[[VecFloat8]] +; CHECK: OpRoundFToTF32INTEL %[[VecFloat16]] + +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64" +target triple = "spir64-unknown-unknown" + +define spir_func void @test(float %a, <8 x float> %in) { + %res1 = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float %a) + %res2 = tail call spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in) + %res3 = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float 1.500000e+00) + ret void +} + +declare spir_func float @_Z25__spirv_RoundFToTF32INTELf(float) +declare spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>) + +define dso_local spir_kernel void @test_ocl(float %a) { +entry: + %res4 = call spir_func float @_Z35intel_round_as_tensor_float32_floatt(float 0.000000e+00) + %res5 = call spir_func <2 x float> @_Z37intel_round_as_tensor_float322_float2Dv2_t(<2 x float> zeroinitializer) + %res6 = call spir_func <3 x float> @_Z37intel_round_as_tensor_float323_float3Dv3_t(<3 x float> zeroinitializer) + %res7 = call spir_func <4 x float> @_Z37intel_round_as_tensor_float324_float4Dv4_t(<4 x float> zeroinitializer) + %res8 = call spir_func <8 x float> @_Z37intel_round_as_tensor_float328_float8Dv8_t(<8 x float> zeroinitializer) + %res9 = call spir_func <16 x float> @_Z39intel_round_as_tensor_float3216_float16Dv16_t(<16 x float> zeroinitializer) + ret void +} + +declare spir_func float @_Z35intel_round_as_tensor_float32_floatt(float) +declare spir_func <2 x float> @_Z37intel_round_as_tensor_float322_float2Dv2_t(<2 x float>) +declare spir_func <3 x float> @_Z37intel_round_as_tensor_float323_float3Dv3_t(<3 x float>) +declare spir_func <4 x 
float> @_Z37intel_round_as_tensor_float324_float4Dv4_t(<4 x float>) +declare spir_func <8 x float> @_Z37intel_round_as_tensor_float328_float8Dv8_t(<8 x float>) +declare spir_func <16 x float> @_Z39intel_round_as_tensor_float3216_float16Dv16_t(<16 x float>)