Skip to content

Commit 3749aa6

Browse files
draft implementation for supporting SPV_INTEL_tensor_float32_conversion
1 parent ee1ecf3 commit 3749aa6

File tree

5 files changed

+47
-4
lines changed

5 files changed

+47
-4
lines changed

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ struct ConvertBuiltin {
148148
bool IsSaturated;
149149
bool IsRounded;
150150
bool IsBfloat16;
151+
bool IsTF32;
151152
FPRoundingMode::FPRoundingMode RoundingMode;
152153
};
153154

@@ -2677,8 +2678,20 @@ static bool generateConvertInst(const StringRef DemangledCall,
26772678
}
26782679
} else if (GR->isScalarOrVectorOfType(Call->ReturnRegister,
26792680
SPIRV::OpTypeFloat)) {
2680-
// Float -> Float
2681-
Opcode = SPIRV::OpFConvert;
2681+
if(Builtin->IsTF32){
2682+
const auto *ST = static_cast<const SPIRVSubtarget *>(
2683+
&MIRBuilder.getMF().getSubtarget());
2684+
if (!ST->canUseExtension(
2685+
SPIRV::Extension::SPV_INTEL_bfloat16_conversion))
2686+
NeedExtMsg = "SPV_INTEL_bfloat16_conversion";
2687+
IsRightComponentsNumber =
2688+
GR->getScalarOrVectorComponentCount(Call->Arguments[0]) ==
2689+
GR->getScalarOrVectorComponentCount(Call->ReturnRegister);
2690+
Opcode = SPIRV::OpRoundFToTF32INTEL;
2691+
} else {
2692+
Float -> Float
2693+
Opcode = SPIRV::OpFConvert;
2694+
}
26822695
}
26832696
}
26842697

llvm/lib/Target/SPIRV/SPIRVBuiltins.td

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1461,6 +1461,7 @@ class ConvertBuiltin<string name, InstructionSet set> {
14611461
bit IsRounded = !not(!eq(!find(name, "_rt"), -1));
14621462
bit IsBfloat16 = !or(!not(!eq(!find(name, "BF16"), -1)),
14631463
!not(!eq(!find(name, "bfloat16"), -1)));
1464+
bit IsTF32 = !not(!eq(!find(name, "TF32"), -1));
14641465
FPRoundingMode RoundingMode = !cond(!not(!eq(!find(name, "_rte"), -1)) : RTE,
14651466
!not(!eq(!find(name, "_rtz"), -1)) : RTZ,
14661467
!not(!eq(!find(name, "_rtp"), -1)) : RTP,
@@ -1472,7 +1473,7 @@ class ConvertBuiltin<string name, InstructionSet set> {
14721473
def ConvertBuiltins : GenericTable {
14731474
let FilterClass = "ConvertBuiltin";
14741475
let Fields = ["Name", "Set", "IsDestinationSigned", "IsSaturated",
1475-
"IsRounded", "IsBfloat16", "RoundingMode"];
1476+
"IsRounded", "IsBfloat16", "IsTF32", "RoundingMode"];
14761477
string TypeOf_Set = "InstructionSet";
14771478
string TypeOf_RoundingMode = "FPRoundingMode";
14781479
}
@@ -1556,6 +1557,24 @@ foreach conv = ["FToBF16INTEL", "BF16ToFINTEL"] in {
15561557
def : ConvertBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std>;
15571558
}
15581559

1560+
// SPV_INTEL_tensor_float32_conversion
1561+
// Multiclass used to define at the same time both a demangled builtin records
1562+
// and a corresponding convert builtin records.
1563+
multiclass DemangledTF32ConvertBuiltin<string name1, string name2> {
1564+
// Create records for scalar and vector conversions.
1565+
foreach i = ["", "2", "3", "4", "8", "16"] in {
1566+
def : DemangledBuiltin<!strconcat("intel_convert_", name1, i, name2, i), OpenCL_std, Convert, 1, 1>;
1567+
def : ConvertBuiltin<!strconcat("intel_convert_", name1, i, name2, i), OpenCL_std>;
1568+
}
1569+
}
1570+
1571+
defm : DemangledTF32ConvertBuiltin<"ConvertFToTF32INTEL">;
1572+
1573+
foreach conv = ["FToTF32INTEL"] in {
1574+
def : DemangledBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std, Convert, 1, 1>;
1575+
def : ConvertBuiltin<!strconcat("__spirv_Convert", conv), OpenCL_std>;
1576+
}
1577+
15591578
//===----------------------------------------------------------------------===//
15601579
// Class defining a vector data load/store builtin record used for lowering
15611580
// into OpExtInst instruction.

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,10 +441,13 @@ def OpBitcast : UnOp<"OpBitcast", 124>;
441441
def OpPtrCastToCrossWorkgroupINTEL : UnOp<"OpPtrCastToCrossWorkgroupINTEL", 5934>;
442442
def OpCrossWorkgroupCastToPtrINTEL : UnOp<"OpCrossWorkgroupCastToPtrINTEL", 5938>;
443443

444-
// SPV_INTEL_bfloat16_conversion
444+
// SPV_INTEL_tensor_float32_conversion
445445
def OpConvertFToBF16INTEL : UnOp<"OpConvertFToBF16INTEL", 6116>;
446446
def OpConvertBF16ToFINTEL : UnOp<"OpConvertBF16ToFINTEL", 6117>;
447447

448+
// SPV_INTEL_bfloat16_conversion
449+
def OpRoundFToTF32INTEL : UnOp<"OpRoundFToTF32INTEL", 6426>;
450+
448451
// 3.42.12 Composite Instructions
449452

450453
def OpVectorExtractDynamic: Op<77, (outs ID:$res), (ins TYPE:$type, vID:$vec, ID:$idx),

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1564,6 +1564,12 @@ void addInstrRequirements(const MachineInstr &MI,
15641564
Reqs.addCapability(SPIRV::Capability::BFloat16ConversionINTEL);
15651565
}
15661566
break;
1567+
case SPIRV::OpRoundFToTF32INTEL:
1568+
if (ST.canUseExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion)) {
1569+
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_tensor_float32_conversion);
1570+
Reqs.addCapability(SPIRV::Capability::TF32ConversionINTEL);
1571+
}
1572+
break;
15671573
case SPIRV::OpVariableLengthArrayINTEL:
15681574
case SPIRV::OpSaveMemoryINTEL:
15691575
case SPIRV::OpRestoreMemoryINTEL:

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>;
320320
defm SPV_INTEL_2d_block_io : ExtensionOperand<122>;
321321
defm SPV_INTEL_int4 : ExtensionOperand<123>;
322322
defm SPV_KHR_float_controls2 : ExtensionOperand<124>;
323+
defm SPV_INTEL_tensor_float32_conversion : ExtensionOperand<125>;
323324

324325
//===----------------------------------------------------------------------===//
325326
// Multiclass used to define Capabilities enum values and at the same time
@@ -502,6 +503,7 @@ defm VariableLengthArrayINTEL : CapabilityOperand<5817, 0, 0, [SPV_INTEL_variabl
502503
defm GroupUniformArithmeticKHR : CapabilityOperand<6400, 0, 0, [SPV_KHR_uniform_group_instructions], []>;
503504
defm USMStorageClassesINTEL : CapabilityOperand<5935, 0, 0, [SPV_INTEL_usm_storage_classes], [Kernel]>;
504505
defm BFloat16ConversionINTEL : CapabilityOperand<6115, 0, 0, [SPV_INTEL_bfloat16_conversion], []>;
506+
defm TF32ConversionINTEL : CapabilityOperand<6425, 0, 0, [SPV_INTEL_tensor_float32_conversion], []>;
505507
defm GlobalVariableHostAccessINTEL : CapabilityOperand<6187, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
506508
defm HostAccessINTEL : CapabilityOperand<6188, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
507509
defm GlobalVariableFPGADecorationsINTEL : CapabilityOperand<6189, 0, 0, [SPV_INTEL_global_variable_fpga_decorations], []>;

0 commit comments

Comments
 (0)