Skip to content

Commit b90fd0c

Browse files
add tests, finalize the implementation and code cleanup
1 parent 3749aa6 commit b90fd0c

File tree

14 files changed

+255
-20
lines changed

14 files changed

+255
-20
lines changed

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall,
231231
// - "__spirv_SubgroupImageMediaBlockReadINTEL"
232232
// - "__spirv_SubgroupImageMediaBlockWriteINTEL"
233233
// - "__spirv_Convert"
234+
// - "__spirv_Round"
234235
// - "__spirv_UConvert"
235236
// - "__spirv_SConvert"
236237
// - "__spirv_FConvert"
@@ -243,7 +244,7 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall,
243244
"SDotKHR|SUDotKHR|SDotAccSatKHR|UDotAccSatKHR|SUDotAccSatKHR|"
244245
"ReadClockKHR|SubgroupBlockReadINTEL|SubgroupImageBlockReadINTEL|"
245246
"SubgroupImageMediaBlockReadINTEL|SubgroupImageMediaBlockWriteINTEL|"
246-
"Convert|"
247+
"Convert|Round"
247248
"UConvert|SConvert|FConvert|SatConvert)[^_]*)(_R[^_]*_?(\\w+)?.*)?");
248249
std::smatch Match;
249250
if (std::regex_match(BuiltinName, Match, SpvWithR) && Match.size() > 1) {
@@ -2682,14 +2683,14 @@ static bool generateConvertInst(const StringRef DemangledCall,
26822683
const auto *ST = static_cast<const SPIRVSubtarget *>(
26832684
&MIRBuilder.getMF().getSubtarget());
26842685
if (!ST->canUseExtension(
2685-
SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
2686-
NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
2686+
SPIRV::Extension::SPV_INTEL_tensor_float32_conversion))
2687+
NeedExtMsg = "SPV_INTEL_tensor_float32_conversion";
26872688
IsRightComponentsNumber =
26882689
GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
26892690
GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
26902691
Opcode = SPIRV::OpRoundFToTF32INTEL;
26912692
} else {
2692-
Float -> Float
2693+
// Float -> Float
26932694
Opcode = SPIRV::OpFConvert;
26942695
}
26952696
}

llvm/lib/Target/SPIRV/SPIRVBuiltins.td

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1461,7 +1461,8 @@ class ConvertBuiltin<string name, InstructionSet set> {
14611461
bit IsRounded = !not(!eq(!find(name, "_rt"), -1));
14621462
bit IsBfloat16 = !or(!not(!eq(!find(name, "BF16"), -1)),
14631463
!not(!eq(!find(name, "bfloat16"), -1)));
1464-
bit IsTF32 = !not(!eq(!find(name, "TF32"), -1));
1464+
bit IsTF32 = !or(!not(!eq(!find(name, "TF32"), -1)),
1465+
!not(!eq(!find(name, "tensor_float32"), -1)));
14651466
FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE,
14661467
!not(!eq(!find(name, "_rtz"), -1)) : RTZ,
14671468
!not(!eq(!find(name, "_rtp"), -1)) : RTP,
@@ -1557,22 +1558,23 @@ foreach conv = ["FToBF16INTEL", "BF16ToFINTEL"] in {
15571558
def : ConvertBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std>;
15581559
}
15591560

1560-
// SPV_INTEL_tensor_float32_conversion
1561+
// cl_intel_tensor_float32_conversions / SPV_INTEL_tensor_float32_conversion
15611562
// Multiclass used to define at the same time both a demangled builtin records
15621563
// and a corresponding convert builtin records.
1563-
multiclass DemangledTF32ConvertBuiltin<string name1, string name2> {
1564+
multiclass DemangledTF32RoundBuiltin<string name1, string name2> {
15641565
// Create records for scalar and vector conversions.
15651566
foreach i = ["", "2", "3", "4", "8", "16"] in {
1566-
def : DemangledBuiltin<!strconcat("intel_convert_", name1, i, name2, i), OpenCL_std, Convert, 1, 1>;
1567-
def : ConvertBuiltin<!strconcat("intel_convert_", name1, i, name2, i), OpenCL_std>;
1567+
def : DemangledBuiltin<!strconcat("intel_round_", name1, i, name2, i), OpenCL_std, Convert, 1, 1>;
1568+
def : ConvertBuiltin<!strconcat("intel_round_", name1, i, name2, i), OpenCL_std>;
15681569
}
15691570
}
15701571

1571-
defm : DemangledTF32ConvertBuiltin<"ConvertFToTF32INTEL">;
1572+
defm : DemangledTF32RoundBuiltin<"tensor_float32", "_as_float">;
1573+
defm : DemangledTF32RoundBuiltin<"as_tensor_float32", "_float">;
15721574

15731575
foreach conv = ["FToTF32INTEL"] in {
1574-
def : DemangledBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std, Convert, 1, 1>;
1575-
def : ConvertBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std>;
1576+
def : DemangledBuiltin<!strconcat("__spirv_Round", conv), OpenCL_std, Convert, 1, 1>;
1577+
def : ConvertBuiltin<!strconcat("__spirv_Round", conv), OpenCL_std>;
15761578
}
15771579

15781580
//===----------------------------------------------------------------------===//

llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
102102
SPIRV::Extension::Extension::SPV_INTEL_2d_block_io},
103103
{"SPV_INTEL_int4", SPIRV::Extension::Extension::SPV_INTEL_int4},
104104
{"SPV_KHR_float_controls2",
105-
SPIRV::Extension::Extension::SPV_KHR_float_controls2}};
105+
SPIRV::Extension::Extension::SPV_KHR_float_controls2},
106+
{"SPV_INTEL_tensor_float32_conversion",
107+
SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}};
106108

107109
bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName,
108110
StringRef ArgValue,

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -441,11 +441,11 @@ def OpBitcast : UnOp<"OpBitcast", 124>;
441441
def OpPtrCastToCrossWorkgroupINTEL : UnOp<"OpPtrCastToCrossWorkgroupINTEL", 5934>;
442442
def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938>;
443443

444-
// SPV_INTEL_tensor_float32_conversion
444+
// SPV_INTEL_bfloat16_conversion
445445
def OpConvertFToBF16INTEL : UnOp<"OpConvertFToBF16INTEL", 6116>;
446446
def OpConvertBF16ToFINTEL : UnOp<"OpConvertBF16ToFINTEL", 6117>;
447447

448-
// SPV_INTEL_bfloat16_conversion
448+
// SPV_INTEL_tensor_float32_conversion
449449
def OpRoundFToTF32INTEL : UnOp<"OpRoundFToTF32INTEL", 6426>;
450450

451451
// 3.42.12 Composite Instructions

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1567,7 +1567,7 @@ void addInstrRequirements(const MachineInstr &MI,
15671567
case SPIRV::OpRoundFToTF32INTEL:
15681568
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) {
15691569
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion);
1570-
Reqs.addCapability(SPIRV::Capability::TF32ConversionINTEL);
1570+
Reqs.addCapability(SPIRV::Capability::TensorFloat32RoundingINTEL);
15711571
}
15721572
break;
15731573
case SPIRV::OpVariableLengthArrayINTEL:

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,6 @@ defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variabl
503503
defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
504504
defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
505505
defm BFloat16ConversionINTEL : CapabilityOperand<6115, 0, 0, [SPV_INTEL_bfloat16_conversion], []>;
506-
defm TF32ConversionINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
507506
defm GlobalVariableHostAccessINTEL : CapabilityOperand<6187, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
508507
defm HostAccessINTEL : CapabilityOperand<6188, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
509508
defm GlobalVariableFPGADecorationsINTEL : CapabilityOperand<6189, 0, 0, [SPV_INTEL_global_variable_fpga_decorations], []>;
@@ -531,6 +530,7 @@ defm Subgroup2DBlockTransformINTEL : CapabilityOperand<6229, 0, 0, [SPV_INTEL_2d
531530
defm Subgroup2DBlockTransposeINTEL : CapabilityOperand<6230, 0, 0, [SPV_INTEL_2d_block_io], [Subgroup2DBlockIOINTEL]>;
532531
defm Int4TypeINTEL : CapabilityOperand<5112, 0, 0, [SPV_INTEL_int4], []>;
533532
defm Int4CooperativeMatrixINTEL : CapabilityOperand<5114, 0, 0, [SPV_INTEL_int4], [Int4TypeINTEL, CooperativeMatrixKHR]>;
533+
defm TensorFloat32RoundingINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
534534

535535
//===----------------------------------------------------------------------===//
536536
// Multiclass used to define SourceLanguage enum values and at the same time
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
2+
; CHECK-ERROR: result and argument must have the same number of components
3+
4+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
5+
target triple = "spir64-unknown-unknown"
6+
7+
define spir_func void @test(<8 x float> %in) {
8+
%res = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in)
9+
ret void
10+
}
11+
12+
declare spir_func float @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
2+
; CHECK-ERROR: result and argument must have the same number of components
3+
4+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
5+
target triple = "spir64-unknown-unknown"
6+
7+
define spir_func void @test(<8 x float> %in) {
8+
%res = tail call spir_func <4 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in)
9+
ret void
10+
}
11+
12+
declare spir_func <4 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>)
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_tensor_float32_conversion %s -o - -filetype=obj | spirv-val %}
3+
4+
; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
5+
; CHECK-ERROR: the builtin requires the following SPIR-V extension: SPV_INTEL_tensor_float32_conversion
6+
7+
; CHECK: OpCapability TensorFloat32RoundingINTEL
8+
; CHECK: OpExtension "SPV_INTEL_tensor_float32_conversion"
9+
10+
; CHECK-DAG: %[[VoidTy:.*]] = OpTypeVoid
11+
; CHECK-DAG: %[[FP32Ty:.*]] = OpTypeFloat 32
12+
; CHECK-DAG: %[[VecFloat2:.*]] = OpTypeVector %[[FP32Ty]] 2
13+
; CHECK-DAG: %[[VecFloat3:.*]] = OpTypeVector %[[FP32Ty]] 3
14+
; CHECK-DAG: %[[VecFloat4:.*]] = OpTypeVector %[[FP32Ty]] 4
15+
; CHECK-DAG: %[[VecFloat8:.*]] = OpTypeVector %[[FP32Ty]] 8
16+
; CHECK-DAG: %[[VecFloat16:.*]] = OpTypeVector %[[FP32Ty]] 16
17+
; CHECK-DAG: %[[FloatConstId:.*]] = OpConstant %[[FP32Ty]] 1.5
18+
19+
; CHECK: OpFunction %[[VoidTy]]
20+
; CHECK: %[[FP32ValId:.*]] = OpFunctionParameter %[[FP32Ty]]
21+
; CHECK: %[[FP32v8ValId:.*]] = OpFunctionParameter %[[VecFloat8]]
22+
; CHECK: OpRoundFToTF32INTEL %[[FP32Ty]] %[[FP32ValId]]
23+
; CHECK: OpRoundFToTF32INTEL %[[VecFloat8]] %[[FP32v8ValId]]
24+
; CHECK: OpRoundFToTF32INTEL %[[FP32Ty]] %[[FloatConstId]]
25+
26+
; CHECK: OpRoundFToTF32INTEL %[[FP32Ty]]
27+
; CHECK: OpRoundFToTF32INTEL %[[VecFloat2]]
28+
; CHECK: OpRoundFToTF32INTEL %[[VecFloat3]]
29+
; CHECK: OpRoundFToTF32INTEL %[[VecFloat4]]
30+
; CHECK: OpRoundFToTF32INTEL %[[VecFloat8]]
31+
; CHECK: OpRoundFToTF32INTEL %[[VecFloat16]]
32+
33+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
34+
target triple = "spir64-unknown-unknown"
35+
36+
define spir_func void @test(float %a, <8 x float> %in) {
37+
%res1 = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float %a)
38+
%res2 = tail call spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float> %in)
39+
%res3 = tail call spir_func float @_Z25__spirv_RoundFToTF32INTELf(float 1.500000e+00)
40+
ret void
41+
}
42+
43+
declare spir_func float @_Z25__spirv_RoundFToTF32INTELf(float)
44+
declare spir_func <8 x float> @_Z25__spirv_RoundFToTF32INTELDv8_f(<8 x float>)
45+
46+
define dso_local spir_kernel void @test_ocl(float %a) {
47+
entry:
48+
%res4 = call spir_func float @_Z35intel_round_as_tensor_float32_floatt(float 0.000000e+00)
49+
%res5 = call spir_func <2 x float> @_Z37intel_round_as_tensor_float322_float2Dv2_t(<2 x float> zeroinitializer)
50+
%res6 = call spir_func <3 x float> @_Z37intel_round_as_tensor_float323_float3Dv3_t(<3 x float> zeroinitializer)
51+
%res7 = call spir_func <4 x float> @_Z37intel_round_as_tensor_float324_float4Dv4_t(<4 x float> zeroinitializer)
52+
%res8 = call spir_func <8 x float> @_Z37intel_round_as_tensor_float328_float8Dv8_t(<8 x float> zeroinitializer)
53+
%res9 = call spir_func <16 x float> @_Z39intel_round_as_tensor_float3216_float16Dv16_t(<16 x float> zeroinitializer)
54+
ret void
55+
}
56+
57+
declare spir_func float @_Z35intel_round_as_tensor_float32_floatt(float)
58+
declare spir_func <2 x float> @_Z37intel_round_as_tensor_float322_float2Dv2_t(<2 x float>)
59+
declare spir_func <3 x float> @_Z37intel_round_as_tensor_float323_float3Dv3_t(<3 x float>)
60+
declare spir_func <4 x float> @_Z37intel_round_as_tensor_float324_float4Dv4_t(<4 x float>)
61+
declare spir_func <8 x float> @_Z37intel_round_as_tensor_float328_float8Dv8_t(<8 x float>)
62+
declare spir_func <16 x float> @_Z39intel_round_as_tensor_float3216_float16Dv16_t(<16 x float>)

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,7 @@ def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_me
405405
def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>;
406406
def SPV_INTEL_bfloat16_conversion : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>;
407407
def SPV_INTEL_cache_controls : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>;
408+
def SPV_INTEL_tensor_float32_conversion : I32EnumAttrCase<"SPV_INTEL_tensor_float32_conversion", 4033>;
408409

409410
def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>;
410411
def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>;
@@ -474,7 +475,8 @@ def SPIRV_ExtensionAttr :
474475
SPV_NV_shader_image_footprint, SPV_NV_shader_sm_builtins,
475476
SPV_NV_shader_subgroup_partitioned, SPV_NV_shading_rate,
476477
SPV_NV_stereo_view_rendering, SPV_NV_viewport_array2, SPV_NV_bindless_texture,
477-
SPV_NV_ray_tracing_motion_blur, SPV_NVX_multiview_per_view_attributes
478+
SPV_NV_ray_tracing_motion_blur, SPV_NVX_multiview_per_view_attributes,
479+
SPV_INTEL_tensor_float32_conversion
478480
]>;
479481

480482
//===----------------------------------------------------------------------===//
@@ -1465,6 +1467,12 @@ def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"B
14651467
];
14661468
}
14671469

1470+
def SPIRV_C_TensorFloat32RoundingINTEL : I32EnumAttrCase<"TensorFloat32RoundingINTEL", 6425> {
1471+
list<Availability> availability = [
1472+
Extension<[SPV_INTEL_tensor_float32_conversion]>
1473+
];
1474+
}
1475+
14681476
def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> {
14691477
list<Availability> availability = [
14701478
Extension<[SPV_INTEL_cache_controls]>
@@ -1567,7 +1575,8 @@ def SPIRV_CapabilityAttr :
15671575
SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
15681576
SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
15691577
SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
1570-
SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR
1578+
SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
1579+
SPIRV_C_TensorFloat32RoundingINTEL
15711580
]>;
15721581

15731582
def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@@ -4586,6 +4595,7 @@ def SPIRV_OC_OpControlBarrierArriveINTEL : I32EnumAttrCase<"OpControlBarrie
45864595
def SPIRV_OC_OpControlBarrierWaitINTEL : I32EnumAttrCase<"OpControlBarrierWaitINTEL", 6143>;
45874596
def SPIRV_OC_OpGroupIMulKHR : I32EnumAttrCase<"OpGroupIMulKHR", 6401>;
45884597
def SPIRV_OC_OpGroupFMulKHR : I32EnumAttrCase<"OpGroupFMulKHR", 6402>;
4598+
def SPIRV_OC_OpRoundFToTF32INTEL : I32EnumAttrCase<"OpRoundFToTF32INTEL", 6426>;
45894599

45904600
def SPIRV_OpcodeAttr :
45914601
SPIRV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
@@ -4690,7 +4700,8 @@ def SPIRV_OpcodeAttr :
46904700
SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT,
46914701
SPIRV_OC_OpConvertFToBF16INTEL, SPIRV_OC_OpConvertBF16ToFINTEL,
46924702
SPIRV_OC_OpControlBarrierArriveINTEL, SPIRV_OC_OpControlBarrierWaitINTEL,
4693-
SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR
4703+
SPIRV_OC_OpGroupIMulKHR, SPIRV_OC_OpGroupFMulKHR,
4704+
SPIRV_OC_OpRoundFToTF32INTEL
46944705
]>;
46954706

46964707
// End opcode section. Generated from SPIR-V spec; DO NOT MODIFY!

0 commit comments

Comments
 (0)