From 3749aa623d613e32f0cc6691faac1cdde6a8ea85 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Tue, 22 Jul 2025 12:14:29 -0700 Subject: [PATCH 1/6] draft implementation for supporting SPV_INTEL_tensor_float32_conversion --- llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 17 +++++++++++++-- llvm/lib/Target/SPIRV/SPIRVBuiltins.td | 21 ++++++++++++++++++- llvm/lib/Target/SPIRV/SPIRVInstrInfo.td | 5 ++++- llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 6 ++++++ .../lib/Target/SPIRV/SPIRVSymbolicOperands.td | 2 ++ 5 files changed, 47 insertions(+), 4 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp index 6ec7544767c52..1c7c1750af1c9 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -148,6 +148,7 @@ struct ConvertBuiltin { bool IsSaturated; bool IsRounded; bool IsBfloat16; + bool IsTF32; FPRoundingMode::FPRoundingMode RoundingMode; }; @@ -2677,8 +2678,20 @@ static bool generateConvertInst(const StringRef DemangledCall, } } else if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeFloat)) { - // Float -> Float - Opcode = SPIRV::OpFConvert; + if(Builtin->IsTF32){ + const auto *ST = static_cast( + &MIRBuilder.getMF().getSubtarget()); + if (!ST->canUseExtension( + SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) + NeedExtMsg = "SPV_INTEL_bfloat16_conversion"; + IsRightComponentsNumber = + GR->getScalarOrVectorComponentCount(Call->Arguments[0]) == + GR->getScalarOrVectorComponentCount(Call->ReturnRegister); + Opcode = SPIRV::OpRoundFToTF32INTEL; + } else { + Float -> Float + Opcode = SPIRV::OpFConvert; + } } } diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td index ea78dcd135267..326109c9fdff4 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td @@ -1461,6 +1461,7 @@ class ConvertBuiltin { bit IsRounded = !not(!eq(!find(name, "_rt"), -1)); bit IsBfloat16 = !or(!not(!eq(!find(name, "BF16"), -1)), !not(!eq(!find(name, "bfloat16"), -1))); + bit IsTF32 = !not(!eq(!find(name, "TF32"), -1)); FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE, !not(!eq(!find(name, "_rtz"), -1)) : RTZ, !not(!eq(!find(name, "_rtp"), -1)) : RTP, @@ -1472,7 +1473,7 @@ class ConvertBuiltin { def ConvertBuiltins : GenericTable { let FilterClass = "ConvertBuiltin"; let Fields = ["Name", "Set", "IsDestinationSigned", "IsSaturated", - "IsRounded", "IsBfloat16", "RoundingMode"]; + "IsRounded", "IsBfloat16", "IsTF32", "RoundingMode"]; string TypeOf_Set = "InstructionSet"; string TypeOf_RoundingMode = "FPRoundingMode"; } @@ -1556,6 +1557,24 @@ foreach conv = ["FToBF16INTEL", "BF16ToFINTEL"] in { def : ConvertBuiltin; } +// SPV_INTEL_tensor_float32_conversion +// Multiclass used to define at the same time both a demangled builtin records +// and a corresponding convert builtin records. +multiclass DemangledTF32ConvertBuiltin { + // Create records for scalar and vector conversions. + foreach i = ["", "2", "3", "4", "8", "16"] in { + def : DemangledBuiltin; + def : ConvertBuiltin; + } +} + +defm : DemangledTF32ConvertBuiltin<"ConvertFToTF32INTEL">; + +foreach conv = ["FToTF32INTEL"] in { + def : DemangledBuiltin; + def : ConvertBuiltin; +} + //===----------------------------------------------------------------------===// // Class defining a vector data load/store builtin record used for lowering // into OpExtInst instruction. diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td index 049ba0275f223..a04ed6a42c868 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td @@ -441,10 +441,13 @@ def OpBitcast : UnOp<"OpBitcast", 124>; def OpPtrCastToCrossWorkgroupINTEL : UnOp<"OpPtrCastToCrossWorkgroupINTEL", 5934>; def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938>; -// SPV_INTEL_bfloat16_conversion +// SPV_INTEL_tensor_float32_conversion def OpConvertFToBF16INTEL : UnOp<"OpConvertFToBF16INTEL", 6116>; def OpConvertBF16ToFINTEL : UnOp<"OpConvertBF16ToFINTEL", 6117>; +// SPV_INTEL_bfloat16_conversion +def OpRoundFToTF32INTEL : UnOp<"OpRoundFToTF32INTEL", 6426>; + // 3.42.12 Composite Instructions def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx), diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index ad976e5288927..c252fc5897518 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1564,6 +1564,12 @@ void addInstrRequirements(const MachineInstr &MI, Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL); } break; + case SPIRV::OpRoundFToTF32INTEL: + if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) { + Reqs.addExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion); + Reqs.addCapability(SPIRV::Capability::TF32ConversionINTEL); + } + break; case SPIRV::OpVariableLengthArrayINTEL: case SPIRV::OpSaveMemoryINTEL: case SPIRV::OpRestoreMemoryINTEL: diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td index 548e9b717c161..7b2139a1c84a8 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td +++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td @@ -320,6 +320,7 @@ defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>; defm SPV_INTEL_2d_block_io : ExtensionOperand<122>; defm SPV_INTEL_int4 : ExtensionOperand<123>; defm SPV_KHR_float_controls2 : ExtensionOperand<124>; +defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125>; //===----------------------------------------------------------------------===// // Multiclass used to define Capabilities enum values and at the same time @@ -502,6 +503,7 @@ defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variabl defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>; defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>; defm BFloat16ConversionINTEL : CapabilityOperand<6115, 0, 0, [SPV_INTEL_bfloat16_conversion], []>; +defm TF32ConversionINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>; defm GlobalVariableHostAccessINTEL : CapabilityOperand<6187, 0, 0, [SPV_INTEL_global_variable_host_access], []>; defm HostAccessINTEL : CapabilityOperand<6188, 0, 0, [SPV_INTEL_global_variable_host_access], []>; defm GlobalVariableFPGADecorationsINTEL : CapabilityOperand<6189, 0, 0, [SPV_INTEL_global_variable_fpga_decorations], []>; From b90fd0c590b67d8a9eb217b82a82ecbf50dacbc3 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Tue, 22 Jul 2025 13:34:02 -0700 Subject: [PATCH 2/6] add tests, finalize the implementation and code cleanup --- llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 9 +-- llvm/lib/Target/SPIRV/SPIRVBuiltins.td | 18 +++--- llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp | 4 +- llvm/lib/Target/SPIRV/SPIRVInstrInfo.td | 4 +- llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 2 +- .../lib/Target/SPIRV/SPIRVSymbolicOperands.td | 2 +- .../tf32-conv-negative1.ll | 12 ++++ .../tf32-conv-negative2.ll | 12 ++++ .../tf32-conv.ll | 62 +++++++++++++++++++ .../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 17 ++++- .../mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td | 54 ++++++++++++++++ mlir/lib/Dialect/SPIRV/IR/CastOps.cpp | 21 +++++++ mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir | 36 +++++++++++ mlir/test/Target/SPIRV/intel-ext-ops.mlir | 22 +++++++ 14 files changed, 255 insertions(+), 20 deletions(-) create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv-negative1.ll create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv-negative2.ll create mode 100644 llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv.ll diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp index 1c7c1750af1c9..03ca2ad1d8fa5 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -231,6 +231,7 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall, // - "__spirv_SubgroupImageMediaBlockReadINTEL" // - "__spirv_SubgroupImageMediaBlockWriteINTEL" // - "__spirv_Convert" + // - "__spirv_Round" // - "__spirv_UConvert" // - "__spirv_SConvert" // - "__spirv_FConvert" @@ -243,7 +244,7 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall, "SDotKHR|SUDotKHR|SDotAccSatKHR|UDotAccSatKHR|SUDotAccSatKHR|" "ReadClockKHR|SubgroupBlockReadINTEL|SubgroupImageBlockReadINTEL|" "SubgroupImageMediaBlockReadINTEL|SubgroupImageMediaBlockWriteINTEL|" - "Convert|" + "Convert|Round" "UConvert|SConvert|FConvert|SatConvert)[^_]*)(_R[^_]*_?(\\w+)?.*)?"); std::smatch Match; if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 1) { @@ -2682,14 +2683,14 @@ static bool generateConvertInst(const StringRef DemangledCall, const auto *ST = static_cast( &MIRBuilder.getMF().getSubtarget()); if (!ST->canUseExtension( - SPIRV::Extension::SPV_INTEL_bfloat16_conversion)) - NeedExtMsg = "SPV_INTEL_bfloat16_conversion"; + SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) + NeedExtMsg = "SPV_INTEL_tensor_float32_conversion"; IsRightComponentsNumber = GR->getScalarOrVectorComponentCount(Call->Arguments[0]) == GR->getScalarOrVectorComponentCount(Call->ReturnRegister); Opcode = SPIRV::OpRoundFToTF32INTEL; } else { - Float -> Float + // Float -> Float Opcode = SPIRV::OpFConvert; } } diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td index 326109c9fdff4..49d11bf7c8dca 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td @@ -1461,7 +1461,8 @@ class ConvertBuiltin { bit IsRounded = !not(!eq(!find(name, "_rt"), -1)); bit IsBfloat16 = !or(!not(!eq(!find(name, "BF16"), -1)), !not(!eq(!find(name, "bfloat16"), -1))); - bit IsTF32 = !not(!eq(!find(name, "TF32"), -1)); + bit IsTF32 = !or(!not(!eq(!find(name, "TF32"), -1)), + !not(!eq(!find(name, "tensor_float32"), -1))); FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE, !not(!eq(!find(name, "_rtz"), -1)) : RTZ, !not(!eq(!find(name, "_rtp"), -1)) : RTP, @@ -1557,22 +1558,23 @@ foreach conv = ["FToBF16INTEL", "BF16ToFINTEL"] in { def : ConvertBuiltin; } -// SPV_INTEL_tensor_float32_conversion +// cl_intel_tensor_float32_conversions / SPV_INTEL_tensor_float32_conversion // Multiclass used to define at the same time both a demangled builtin records // and a corresponding convert builtin records. -multiclass DemangledTF32ConvertBuiltin { +multiclass DemangledTF32RoundBuiltin { // Create records for scalar and vector conversions. foreach i = ["", "2", "3", "4", "8", "16"] in { - def : DemangledBuiltin; - def : ConvertBuiltin; + def : DemangledBuiltin; + def : ConvertBuiltin; } } -defm : DemangledTF32ConvertBuiltin<"ConvertFToTF32INTEL">; +defm : DemangledTF32RoundBuiltin<"tensor_float32", "_as_float">; +defm : DemangledTF32RoundBuiltin<"as_tensor_float32", "_float">; foreach conv = ["FToTF32INTEL"] in { - def : DemangledBuiltin; - def : ConvertBuiltin; + def : DemangledBuiltin; + def : ConvertBuiltin; } //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp index 2726203d253ad..945d3febe0bcf 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp @@ -102,7 +102,9 @@ static const std::map> SPIRV::Extension::Extension::SPV_INTEL_2d_block_io}, {"SPV_INTEL_int4", SPIRV::Extension::Extension::SPV_INTEL_int4}, {"SPV_KHR_float_controls2", - SPIRV::Extension::Extension::SPV_KHR_float_controls2}}; + SPIRV::Extension::Extension::SPV_KHR_float_controls2}, + {"SPV_INTEL_tensor_float32_conversion", + SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}}; bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName, StringRef ArgValue, diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td index a04ed6a42c868..f0b938d681dba 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td @@ -441,11 +441,11 @@ def OpBitcast : UnOp<"OpBitcast", 124>; def OpPtrCastToCrossWorkgroupINTEL : UnOp<"OpPtrCastToCrossWorkgroupINTEL", 5934>; def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938>; -// SPV_INTEL_tensor_float32_conversion +// SPV_INTEL_bfloat16_conversion def OpConvertFToBF16INTEL : UnOp<"OpConvertFToBF16INTEL", 6116>; def OpConvertBF16ToFINTEL : UnOp<"OpConvertBF16ToFINTEL", 6117>; -// SPV_INTEL_bfloat16_conversion +// SPV_INTEL_tensor_float32_conversion def OpRoundFToTF32INTEL : UnOp<"OpRoundFToTF32INTEL", 6426>; // 3.42.12 Composite Instructions diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index c252fc5897518..eac337c3c4246 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1567,7 +1567,7 @@ void addInstrRequirements(const MachineInstr &MI, case SPIRV::OpRoundFToTF32INTEL: if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) { Reqs.addExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion); - Reqs.addCapability(SPIRV::Capability::TF32ConversionINTEL); + Reqs.addCapability(SPIRV::Capability::TensorFloat32RoundingINTEL); } break; case SPIRV::OpVariableLengthArrayINTEL: diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td index 7b2139a1c84a8..614e83ae9b286 100644 --- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td +++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td @@ -503,7 +503,6 @@ defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variabl defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>; defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>; defm BFloat16ConversionINTEL : CapabilityOperand<6115, 0, 0, [SPV_INTEL_bfloat16_conversion], []>; -defm TF32ConversionINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>; defm GlobalVariableHostAccessINTEL : CapabilityOperand<6187, 0, 0, [SPV_INTEL_global_variable_host_access], []>; defm HostAccessINTEL : CapabilityOperand<6188, 0, 0, [SPV_INTEL_global_variable_host_access], []>; defm GlobalVariableFPGADecorationsINTEL : CapabilityOperand<6189, 0, 0, [SPV_INTEL_global_variable_fpga_decorations], []>; @@ -531,6 +530,7 @@ defm Subgroup2DBlockTransformINTEL : CapabilityOperand<6229, 0, 0, [SPV_INTEL_2d defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>; defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>; defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>; +defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>; //===----------------------------------------------------------------------===// // Multiclass used to define SourceLanguage enum values and at the same time diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv-negative1.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv-negative1.ll new file mode 100644 index 0000000000000..fa708ab022a85 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv-negative1.ll @@ -0,0 +1,12 @@ +; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR +; CHECK-ERROR: result and argument must have the same number of components + +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64" +target triple = "spir64-unknown-unknown" + +define spir_func void @test(<8 x float> %in) { + %res = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in) + ret void +} + +declare spir_func float @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>) diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv-negative2.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv-negative2.ll new file mode 100644 index 0000000000000..630b2fdd7696c --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv-negative2.ll @@ -0,0 +1,12 @@ +; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR +; CHECK-ERROR: result and argument must have the same number of components + +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64" +target triple = "spir64-unknown-unknown" + +define spir_func void @test(<8 x float> %in) { + %res = tail call spir_func <4 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in) + ret void +} + +declare spir_func <4 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>) diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv.ll new file mode 100644 index 0000000000000..dcad78d17bff7 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_tensor_float32_conversion/tf32-conv.ll @@ -0,0 +1,62 @@ +; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o - -filetype=obj | spirv-val %} + +; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR +; CHECK-ERROR: the builtin requires the following SPIR-V extension: SPV_INTEL_tensor_float32_conversion + +; CHECK: OpCapability TensorFloat32RoundingINTEL +; CHECK: OpExtension "SPV_INTEL_tensor_float32_conversion" + +; CHECK-DAG: %[[VoidTy:.*]] = OpTypeVoid +; CHECK-DAG: %[[FP32Ty:.*]] = OpTypeFloat 32 +; CHECK-DAG: %[[VecFloat2:.*]] = OpTypeVector %[[FP32Ty]] 2 +; CHECK-DAG: %[[VecFloat3:.*]] = OpTypeVector %[[FP32Ty]] 3 +; CHECK-DAG: %[[VecFloat4:.*]] = OpTypeVector %[[FP32Ty]] 4 +; CHECK-DAG: %[[VecFloat8:.*]] = OpTypeVector %[[FP32Ty]] 8 +; CHECK-DAG: %[[VecFloat16:.*]] = OpTypeVector %[[FP32Ty]] 16 +; CHECK-DAG: %[[FloatConstId:.*]] = OpConstant %[[FP32Ty]] 1.5 + +; CHECK: OpFunction %[[VoidTy]] +; CHECK: %[[FP32ValId:.*]] = OpFunctionParameter %[[FP32Ty]] +; CHECK: %[[FP32v8ValId:.*]] = OpFunctionParameter %[[VecFloat8]] +; CHECK: OpRoundFToTF32INTEL %[[FP32Ty]] %[[FP32ValId]] +; CHECK: OpRoundFToTF32INTEL %[[VecFloat8]] %[[FP32v8ValId]] +; CHECK: OpRoundFToTF32INTEL %[[FP32Ty]] %[[FloatConstId]] + +; CHECK: OpRoundFToTF32INTEL %[[FP32Ty]] +; CHECK: OpRoundFToTF32INTEL %[[VecFloat2]] +; CHECK: OpRoundFToTF32INTEL %[[VecFloat3]] +; CHECK: OpRoundFToTF32INTEL %[[VecFloat4]] +; CHECK: OpRoundFToTF32INTEL %[[VecFloat8]] +; CHECK: OpRoundFToTF32INTEL %[[VecFloat16]] + +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64" +target triple = "spir64-unknown-unknown" + +define spir_func void @test(float %a, <8 x float> %in) { + %res1 = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float %a) + %res2 = tail call spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in) + %res3 = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float 1.500000e+00) + ret void +} + +declare spir_func float @_Z25__spirv_RoundFToTF32INTELf(float) +declare spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>) + +define dso_local spir_kernel void @test_ocl(float %a) { +entry: + %res4 = call spir_func float @_Z35intel_round_as_tensor_float32_floatt(float 0.000000e+00) + %res5 = call spir_func <2 x float> @_Z37intel_round_as_tensor_float322_float2Dv2_t(<2 x float> zeroinitializer) + %res6 = call spir_func <3 x float> @_Z37intel_round_as_tensor_float323_float3Dv3_t(<3 x float> zeroinitializer) + %res7 = call spir_func <4 x float> @_Z37intel_round_as_tensor_float324_float4Dv4_t(<4 x float> zeroinitializer) + %res8 = call spir_func <8 x float> @_Z37intel_round_as_tensor_float328_float8Dv8_t(<8 x float> zeroinitializer) + %res9 = call spir_func <16 x float> @_Z39intel_round_as_tensor_float3216_float16Dv16_t(<16 x float> zeroinitializer) + ret void +} + +declare spir_func float @_Z35intel_round_as_tensor_float32_floatt(float) +declare spir_func <2 x float> @_Z37intel_round_as_tensor_float322_float2Dv2_t(<2 x float>) +declare spir_func <3 x float> @_Z37intel_round_as_tensor_float323_float3Dv3_t(<3 x float>) +declare spir_func <4 x float> @_Z37intel_round_as_tensor_float324_float4Dv4_t(<4 x float>) +declare spir_func <8 x float> @_Z37intel_round_as_tensor_float328_float8Dv8_t(<8 x float>) +declare spir_func <16 x float> @_Z39intel_round_as_tensor_float3216_float16Dv16_t(<16 x float>) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index 90383265002a3..9c9eefd054fa6 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -405,6 +405,7 @@ def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_me def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>; def SPV_INTEL_bfloat16_conversion : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>; def SPV_INTEL_cache_controls : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>; +def SPV_INTEL_tensor_float32_conversion : I32EnumAttrCase<"SPV_INTEL_tensor_float32_conversion", 4033>; def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>; def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>; @@ -474,7 +475,8 @@ def SPIRV_ExtensionAttr : SPV_NV_shader_image_footprint, SPV_NV_shader_sm_builtins, SPV_NV_shader_subgroup_partitioned, SPV_NV_shading_rate, SPV_NV_stereo_view_rendering, SPV_NV_viewport_array2, SPV_NV_bindless_texture, - SPV_NV_ray_tracing_motion_blur, SPV_NVX_multiview_per_view_attributes + SPV_NV_ray_tracing_motion_blur, SPV_NVX_multiview_per_view_attributes, + SPV_INTEL_tensor_float32_conversion ]>; //===----------------------------------------------------------------------===// @@ -1465,6 +1467,12 @@ def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"B ]; } +def SPIRV_C_TensorFloat32RoundingINTEL : I32EnumAttrCase<"TensorFloat32RoundingINTEL", 6425> { + list availability = [ + Extension<[SPV_INTEL_tensor_float32_conversion]> + ]; +} + def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> { list availability = [ Extension<[SPV_INTEL_cache_controls]> @@ -1567,7 +1575,8 @@ def SPIRV_CapabilityAttr : SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV, SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL, SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR, - SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR + SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR, + SPIRV_C_TensorFloat32RoundingINTEL ]>; def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>; @@ -4586,6 +4595,7 @@ def SPIRV_OC_OpControlBarrierArriveINTEL : I32EnumAttrCase<"OpControlBarrie def SPIRV_OC_OpControlBarrierWaitINTEL : I32EnumAttrCase<"OpControlBarrierWaitINTEL", 6143>; def SPIRV_OC_OpGroupIMulKHR : I32EnumAttrCase<"OpGroupIMulKHR", 6401>; def SPIRV_OC_OpGroupFMulKHR : I32EnumAttrCase<"OpGroupFMulKHR", 6402>; +def SPIRV_OC_OpRoundFToTF32INTEL : I32EnumAttrCase<"OpRoundFToTF32INTEL", 6426>; def SPIRV_OpcodeAttr : SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [ @@ -4690,7 +4700,8 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL, SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL, - SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR + SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR, + SPIRV_OC_OpRoundFToTF32INTEL ]>; // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY! diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td index 82d26e365fb24..b692c07122683 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td @@ -11,6 +11,7 @@ // at (https://github.com/intel/llvm) // Supported extensions // * SPV_INTEL_bfloat16_conversion +// * SPV_INTEL_tensor_float32_conversion //===----------------------------------------------------------------------===// @@ -110,6 +111,59 @@ def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> { let hasVerifier = 1; } +// ----- + +def SPIRV_INTELRoundFToTF32Op : SPIRV_IntelVendorOp<"RoundFToTF32", []> { + let summary = "See extension SPV_INTEL_tensor_float32_conversion"; + + let description = [{ + Convert value numerically from a 32-bit floating point type to tensor float32, + with rounding to the nearest even. + + Result Type must be a scalar or vector of 32-bit floating-point type. + The component width must be 32 bits. Bit pattern in the Result represents a tensor float32 value. + + Float Value must be a scalar or vector of floating-point type. + It must have the same number of components as Result Type. The component width must be 32 bits. + + Results are computed per component. + + + ``` + convert-f-to-tf32-op ::= ssa-id `=` `spirv.INTEL.RoundFToTF32` ssa-use + `:` operand-type `to` result-type + ``` + + #### Example: + + ```mlir + %1 = spirv.RoundFToTF32 %0 : f32 to f32 + %3 = spirv.RoundFToTF32 %2 : vector<3xf32> to vector<3xf32> + ``` + + }]; + + + let availability = [ + MinVersion, + MaxVersion, + Extension<[SPV_INTEL_tensor_float32_conversion]>, + Capability<[SPIRV_C_TensorFloat32RoundingINTEL]> + ]; + + let arguments = (ins + SPIRV_ScalarOrVectorOf:$operand + ); + + let results = (outs + SPIRV_ScalarOrVectorOf:$result + ); + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `to` type($result) + }]; + + let hasVerifier = 1; +} // ----- diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp index e27dc274673be..fc3e7308356bf 100644 --- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp @@ -311,6 +311,27 @@ LogicalResult INTELConvertFToBF16Op::verify() { return success(); } +//===----------------------------------------------------------------------===// +// spirv.INTELRoundFToTF32Op +//===----------------------------------------------------------------------===// + +LogicalResult INTELRoundFToTF32Op::verify() { + auto operandType = getOperand().getType(); + auto resultType = getResult().getType(); + // ODS checks that vector result type and vector operand type have the same + // shape. + if (auto vectorType = llvm::dyn_cast(operandType)) { + unsigned operandNumElements = vectorType.getNumElements(); + unsigned resultNumElements = + llvm::cast(resultType).getNumElements(); + if (operandNumElements != resultNumElements) { + return emitOpError( + "operand and result must have same number of elements"); + } + } + return success(); +} + //===----------------------------------------------------------------------===// // spirv.FConvertOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir index bb15d018a6c44..aa5bee5796cfa 100644 --- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir @@ -72,6 +72,42 @@ spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" { // ----- +//===----------------------------------------------------------------------===// +// spirv.INTEL.RoundFToTF32 +//===----------------------------------------------------------------------===// + +spirv.func @f32_to_tf32(%arg0 : f32) "None" { + // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32 + %0 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32 + spirv.Return +} + +// ----- + +spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" { + // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32> + %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32> + spirv.Return +} + +// ----- + +spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" { + // expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}} + %0 = spirv.INTEL.RoundFToTF32 %arg0 : f64 to f32 + spirv.Return +} + +// ----- + +spirv.func @f32_to_tf32_vec_unsupported(%arg0 : vector<2xf32>) "None" { + // expected-error @+1 {{operand and result must have same number of elements}} + %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<4xf32> + spirv.Return +} + +// ----- + //===----------------------------------------------------------------------===// // spirv.INTEL.SplitBarrier //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/SPIRV/intel-ext-ops.mlir b/mlir/test/Target/SPIRV/intel-ext-ops.mlir index 6d2fd324363c6..53cf8bf8fbd62 100644 --- a/mlir/test/Target/SPIRV/intel-ext-ops.mlir +++ b/mlir/test/Target/SPIRV/intel-ext-ops.mlir @@ -32,6 +32,28 @@ spirv.module Logical GLSL450 requires #spirv.vce { + // CHECK-LABEL: @f32_to_tf32 + spirv.func @f32_to_tf32(%arg0 : f32) "None" { + // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32 + %1 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32 + spirv.Return + } + + // CHECK-LABEL: @f32_to_tf32_vec + spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" { + // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32> + %1 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32> + spirv.Return + } +} + +// ----- + //===----------------------------------------------------------------------===// // spirv.INTEL.SplitBarrier //===----------------------------------------------------------------------===// From 6135f1414ed426c510c9c626ccc6bc876e2d0e22 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Tue, 29 Jul 2025 19:18:46 -0700 Subject: [PATCH 3/6] fix clang format --- llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 9 +++++---- llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp | 2 +- llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp | 3 ++- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp index 03ca2ad1d8fa5..a3c9cb96d013b 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -699,7 +699,8 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call, MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) { if (Call->isSpirvOp()) - return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0)); + return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, + Register(0)); Register ScopeRegister = buildConstantIntReg32(SPIRV::Scope::Device, MIRBuilder, GR); @@ -2679,13 +2680,13 @@ static bool generateConvertInst(const StringRef DemangledCall, } } else if (GR->isScalarOrVectorOfType(Call->ReturnRegister, SPIRV::OpTypeFloat)) { - if(Builtin->IsTF32){ + if (Builtin->IsTF32) { const auto *ST = static_cast( - &MIRBuilder.getMF().getSubtarget()); + &MIRBuilder.getMF().getSubtarget()); if (!ST->canUseExtension( SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) NeedExtMsg = "SPV_INTEL_tensor_float32_conversion"; - IsRightComponentsNumber = + IsRightComponentsNumber = GR->getScalarOrVectorComponentCount(Call->Arguments[0]) == GR->getScalarOrVectorComponentCount(Call->ReturnRegister); Opcode = SPIRV::OpRoundFToTF32INTEL; diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp index 945d3febe0bcf..d9265f498973e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp @@ -103,7 +103,7 @@ static const std::map> {"SPV_INTEL_int4", SPIRV::Extension::Extension::SPV_INTEL_int4}, {"SPV_KHR_float_controls2", SPIRV::Extension::Extension::SPV_KHR_float_controls2}, - {"SPV_INTEL_tensor_float32_conversion", + {"SPV_INTEL_tensor_float32_conversion", SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}}; bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName, diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index eac337c3c4246..0cd9d7882a52a 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1565,7 +1565,8 @@ void addInstrRequirements(const MachineInstr &MI, } break; case SPIRV::OpRoundFToTF32INTEL: - if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) { + if (ST.canUseExtension( + SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) { Reqs.addExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion); Reqs.addCapability(SPIRV::Capability::TensorFloat32RoundingINTEL); } From 8528c2f9f755b14af347a1f46274014bf7749388 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Tue, 29 Jul 2025 20:07:24 -0700 Subject: [PATCH 4/6] fix the CI test failure --- llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp index a3c9cb96d013b..25cdf72a658a8 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp @@ -244,7 +244,7 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall, "SDotKHR|SUDotKHR|SDotAccSatKHR|UDotAccSatKHR|SUDotAccSatKHR|" "ReadClockKHR|SubgroupBlockReadINTEL|SubgroupImageBlockReadINTEL|" "SubgroupImageMediaBlockReadINTEL|SubgroupImageMediaBlockWriteINTEL|" - "Convert|Round" + "Convert|Round|" "UConvert|SConvert|FConvert|SatConvert)[^_]*)(_R[^_]*_?(\\w+)?.*)?"); std::smatch Match; if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 1) { From 5b4b793154762cda696b9d301aaf64ef56757e2c Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Wed, 30 Jul 2025 06:55:50 -0700 Subject: [PATCH 5/6] move the mlir part to a new PR --- .../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 17 ++---- .../mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td | 54 ------------------- mlir/lib/Dialect/SPIRV/IR/CastOps.cpp | 21 -------- mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir | 36 ------------- mlir/test/Target/SPIRV/intel-ext-ops.mlir | 22 -------- 5 files changed, 3 insertions(+), 147 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index 9c9eefd054fa6..90383265002a3 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -405,7 +405,6 @@ def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_me def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>; def SPV_INTEL_bfloat16_conversion : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>; def SPV_INTEL_cache_controls : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>; -def SPV_INTEL_tensor_float32_conversion : I32EnumAttrCase<"SPV_INTEL_tensor_float32_conversion", 4033>; def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>; def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>; @@ -475,8 +474,7 @@ def SPIRV_ExtensionAttr : SPV_NV_shader_image_footprint, SPV_NV_shader_sm_builtins, SPV_NV_shader_subgroup_partitioned, SPV_NV_shading_rate, SPV_NV_stereo_view_rendering, SPV_NV_viewport_array2, SPV_NV_bindless_texture, - SPV_NV_ray_tracing_motion_blur, SPV_NVX_multiview_per_view_attributes, - SPV_INTEL_tensor_float32_conversion + SPV_NV_ray_tracing_motion_blur, SPV_NVX_multiview_per_view_attributes ]>; //===----------------------------------------------------------------------===// @@ -1467,12 +1465,6 @@ def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"B ]; } -def SPIRV_C_TensorFloat32RoundingINTEL : I32EnumAttrCase<"TensorFloat32RoundingINTEL", 6425> { - list availability = [ - Extension<[SPV_INTEL_tensor_float32_conversion]> - ]; -} - def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> { list availability = [ Extension<[SPV_INTEL_cache_controls]> @@ -1575,8 +1567,7 @@ def SPIRV_CapabilityAttr : SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV, SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL, SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR, - SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR, - SPIRV_C_TensorFloat32RoundingINTEL + SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR ]>; def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>; @@ -4595,7 +4586,6 @@ def SPIRV_OC_OpControlBarrierArriveINTEL : I32EnumAttrCase<"OpControlBarrie def SPIRV_OC_OpControlBarrierWaitINTEL : I32EnumAttrCase<"OpControlBarrierWaitINTEL", 6143>; def SPIRV_OC_OpGroupIMulKHR : I32EnumAttrCase<"OpGroupIMulKHR", 6401>; def SPIRV_OC_OpGroupFMulKHR : I32EnumAttrCase<"OpGroupFMulKHR", 6402>; -def SPIRV_OC_OpRoundFToTF32INTEL : I32EnumAttrCase<"OpRoundFToTF32INTEL", 6426>; def SPIRV_OpcodeAttr : SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [ @@ -4700,8 +4690,7 @@ def SPIRV_OpcodeAttr : SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL, SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL, - SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR, - SPIRV_OC_OpRoundFToTF32INTEL + SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR ]>; // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY! diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td index b692c07122683..82d26e365fb24 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVIntelExtOps.td @@ -11,7 +11,6 @@ // at (https://github.com/intel/llvm) // Supported extensions // * SPV_INTEL_bfloat16_conversion -// * SPV_INTEL_tensor_float32_conversion //===----------------------------------------------------------------------===// @@ -111,59 +110,6 @@ def SPIRV_INTELConvertBF16ToFOp : SPIRV_IntelVendorOp<"ConvertBF16ToF", []> { let hasVerifier = 1; } -// ----- - -def SPIRV_INTELRoundFToTF32Op : SPIRV_IntelVendorOp<"RoundFToTF32", []> { - let summary = "See extension SPV_INTEL_tensor_float32_conversion"; - - let description = [{ - Convert value numerically from a 32-bit floating point type to tensor float32, - with rounding to the nearest even. - - Result Type must be a scalar or vector of 32-bit floating-point type. - The component width must be 32 bits. Bit pattern in the Result represents a tensor float32 value. - - Float Value must be a scalar or vector of floating-point type. - It must have the same number of components as Result Type. The component width must be 32 bits. - - Results are computed per component. - - - ``` - convert-f-to-tf32-op ::= ssa-id `=` `spirv.INTEL.RoundFToTF32` ssa-use - `:` operand-type `to` result-type - ``` - - #### Example: - - ```mlir - %1 = spirv.RoundFToTF32 %0 : f32 to f32 - %3 = spirv.RoundFToTF32 %2 : vector<3xf32> to vector<3xf32> - ``` - - }]; - - - let availability = [ - MinVersion, - MaxVersion, - Extension<[SPV_INTEL_tensor_float32_conversion]>, - Capability<[SPIRV_C_TensorFloat32RoundingINTEL]> - ]; - - let arguments = (ins - SPIRV_ScalarOrVectorOf:$operand - ); - - let results = (outs - SPIRV_ScalarOrVectorOf:$result - ); - let assemblyFormat = [{ - $operand attr-dict `:` type($operand) `to` type($result) - }]; - - let hasVerifier = 1; -} // ----- diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp index fc3e7308356bf..e27dc274673be 100644 --- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp @@ -311,27 +311,6 @@ LogicalResult INTELConvertFToBF16Op::verify() { return success(); } -//===----------------------------------------------------------------------===// -// spirv.INTELRoundFToTF32Op -//===----------------------------------------------------------------------===// - -LogicalResult INTELRoundFToTF32Op::verify() { - auto operandType = getOperand().getType(); - auto resultType = getResult().getType(); - // ODS checks that vector result type and vector operand type have the same - // shape. - if (auto vectorType = llvm::dyn_cast(operandType)) { - unsigned operandNumElements = vectorType.getNumElements(); - unsigned resultNumElements = - llvm::cast(resultType).getNumElements(); - if (operandNumElements != resultNumElements) { - return emitOpError( - "operand and result must have same number of elements"); - } - } - return success(); -} - //===----------------------------------------------------------------------===// // spirv.FConvertOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir index aa5bee5796cfa..bb15d018a6c44 100644 --- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir @@ -72,42 +72,6 @@ spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" { // ----- -//===----------------------------------------------------------------------===// -// spirv.INTEL.RoundFToTF32 -//===----------------------------------------------------------------------===// - -spirv.func @f32_to_tf32(%arg0 : f32) "None" { - // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32 - %0 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32 - spirv.Return -} - -// ----- - -spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" { - // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32> - %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32> - spirv.Return -} - -// ----- - -spirv.func @f32_to_tf32_unsupported(%arg0 : f64) "None" { - // expected-error @+1 {{operand #0 must be Float32 or vector of Float32 values of length 2/3/4/8/16, but got}} - %0 = spirv.INTEL.RoundFToTF32 %arg0 : f64 to f32 - spirv.Return -} - -// ----- - -spirv.func @f32_to_tf32_vec_unsupported(%arg0 : vector<2xf32>) "None" { - // expected-error @+1 {{operand and result must have same number of elements}} - %0 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<4xf32> - spirv.Return -} - -// ----- - //===----------------------------------------------------------------------===// // spirv.INTEL.SplitBarrier //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/SPIRV/intel-ext-ops.mlir b/mlir/test/Target/SPIRV/intel-ext-ops.mlir index 53cf8bf8fbd62..6d2fd324363c6 100644 --- a/mlir/test/Target/SPIRV/intel-ext-ops.mlir +++ b/mlir/test/Target/SPIRV/intel-ext-ops.mlir @@ -32,28 +32,6 @@ spirv.module Logical GLSL450 requires #spirv.vce { - // CHECK-LABEL: @f32_to_tf32 - spirv.func @f32_to_tf32(%arg0 : f32) "None" { - // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : f32 to f32 - %1 = spirv.INTEL.RoundFToTF32 %arg0 : f32 to f32 - spirv.Return - } - - // CHECK-LABEL: @f32_to_tf32_vec - spirv.func @f32_to_tf32_vec(%arg0 : vector<2xf32>) "None" { - // CHECK: {{%.*}} = spirv.INTEL.RoundFToTF32 {{%.*}} : vector<2xf32> to vector<2xf32> - %1 = spirv.INTEL.RoundFToTF32 %arg0 : vector<2xf32> to vector<2xf32> - spirv.Return - } -} - -// ----- - //===----------------------------------------------------------------------===// // spirv.INTEL.SplitBarrier //===----------------------------------------------------------------------===// From 608c8fae2d23b51608b6e43905b9db9880a46a20 Mon Sep 17 00:00:00 2001 From: "Zhang, Yixing" Date: Wed, 30 Jul 2025 08:44:45 -0700 Subject: [PATCH 6/6] solve the grammar mistake --- llvm/lib/Target/SPIRV/SPIRVBuiltins.td | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td index 49d11bf7c8dca..d08560bb6565a 100644 --- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td +++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td @@ -1559,8 +1559,8 @@ foreach conv = ["FToBF16INTEL", "BF16ToFINTEL"] in { } // cl_intel_tensor_float32_conversions / SPV_INTEL_tensor_float32_conversion -// Multiclass used to define at the same time both a demangled builtin records -// and a corresponding convert builtin records. +// Multiclass used to define at the same time both a demangled builtin record +// and a corresponding convert builtin record. multiclass DemangledTF32RoundBuiltin { // Create records for scalar and vector conversions. foreach i = ["", "2", "3", "4", "8", "16"] in {