|
| 1 | +=================================================================================== |
| 2 | +SPIR-V representation in LLVM IR for FP8, FP4 and Int4 datatypes |
| 3 | +=================================================================================== |
| 4 | +.. contents:: |
| 5 | + :local: |
| 6 | + |
| 7 | +Overview |
| 8 | +======== |
| 9 | + |
| 10 | +Open Compute and other projects are adding various new data-types and SPIR-V |
| 11 | +(starting from SPV_EXT_float8) is now adopting them. None of these data-types |
| 12 | +have appropriate LLVM IR counterparts. This document describes the proposed |
| 13 | +LLVM IR input format for *FP8*, *FP4*, and *Int4* types, the translation flow, and the |
| 14 | +expected LLVM IR output from the consumer. |
| 15 | + |
| 16 | +SPIR-V Non-standard Types Mapped to LLVM Types |
| 17 | +============================================== |
| 18 | + |
| 19 | +All formats of *FP8* (E4M3, E5M2), *FP4* (E2M1), and *Int4* will be represented in LLVM IR with |
| 20 | +integer types (*i8* for FP8, *i4* for FP4 and Int4). |
| 21 | +Until 'type resolution' instruction appears in the module (see below), these values will |
| 22 | +remain as integers. When 'type resolution' instruction is being processed, integer values will be bitcasted |
| 23 | +to floating-point or integer values with appropriate width and encoding depending on the instruction. If the instruction's |
| 24 | +result is *FP8*, *FP4*, *Int4*, or a composite containing them, then it is also being bitcasted to the |
| 25 | +appropriate integer type or composite. It is safe to do as these extensions don't add support |
| 26 | +for arithmetic instructions and builtins (unless it's *OpCooperativeMatrixMulAddKHR*, but this |
| 27 | +case will be handled separately). |
| 28 | + |
| 29 | +The 'type resolution' instruction can be either a conversion instruction or *OpCooperativeMatrixMulAddKHR*. |
| 30 | + |
| 31 | +**Type mappings:** |
| 32 | + |
| 33 | +* FP8 (E4M3, E5M2) → *i8* |
| 34 | +* FP4 (E2M1) → *i4* |
| 35 | +* Int4 → *i4* |
| 36 | + |
| 37 | +SPIR-V conversion instructions |
| 38 | +============================== |
| 39 | + |
| 40 | +Most conversions will be represented by standard SPIR-V conversion instructions (*OpFConvert*, *OpConvertSToF*, *OpConvertFToS*, |
| 41 | +*OpConvertUToF*, *OpConvertFToU*, *OpSConvert*), which don't carry information about floating-point value's width and encoding. |
| 42 | +This document adds a new set of external function calls, each of which has a name that is formed from encoding a specific conversion |
| 43 | +that it performs. This name has a *__builtin_spirv_* prefix and a postfix indicating the extension (e.g., *EXT* from SPV_EXT_float8, |
| 44 | +*INTEL* from SPV_INTEL_int4/SPV_INTEL_float4/SPV_INTEL_fp_conversions). These calls will be translated to SPIR-V conversion |
| 45 | +instructions operating over the appropriate types. These functions are expected to be mangled following Itanium C++ ABI. SPIR-V consumer |
| 46 | +will apply Itanium mangling during translation to LLVM IR as well. |
| 47 | + |
| 48 | +SPIR-V generator will support *scalar*, *vector* and *packed* for the conversion builtin functions as LLVM IR input; |
| 49 | +*packed* format is translated to a *vector*. Meanwhile SPIR-V consumer will never pack a *vector* back to *packed* format. |
| 50 | + |
| 51 | +SPV_EXT_float8 and SPV_KHR_bfloat16 Conversions |
| 52 | +------------------------------------------------ |
| 53 | + |
| 54 | +**Translated to OpFConvert:** |
| 55 | + |
| 56 | +.. code-block:: C |
| 57 | +
|
| 58 | + __builtin_spirv_ConvertFP16ToE4M3EXT, __builtin_spirv_ConvertBF16ToE4M3EXT, |
| 59 | + __builtin_spirv_ConvertFP16ToE5M2EXT, __builtin_spirv_ConvertBF16ToE5M2EXT, |
| 60 | + __builtin_spirv_ConvertE4M3ToFP16EXT, __builtin_spirv_ConvertE5M2ToFP16EXT, |
| 61 | + __builtin_spirv_ConvertE4M3ToBF16EXT, __builtin_spirv_ConvertE5M2ToBF16EXT |
| 62 | +
|
| 63 | +SPV_INTEL_int4 Conversions |
| 64 | +--------------------------- |
| 65 | + |
| 66 | +**Translated to OpConvertSToF:** |
| 67 | + |
| 68 | +.. code-block:: C |
| 69 | +
|
| 70 | + __builtin_spirv_ConvertInt4ToE4M3INTEL, __builtin_spirv_ConvertInt4ToE5M2INTEL, |
| 71 | + __builtin_spirv_ConvertInt4ToFP16INTEL, __builtin_spirv_ConvertInt4ToBF16INTEL |
| 72 | +
|
| 73 | +**Translated to OpConvertFToS:** |
| 74 | + |
| 75 | +.. code-block:: C |
| 76 | +
|
| 77 | + __builtin_spirv_ConvertFP16ToInt4INTEL, __builtin_spirv_ConvertBF16ToInt4INTEL |
| 78 | +
|
| 79 | +**Translated to OpSConvert:** |
| 80 | + |
| 81 | +.. code-block:: C |
| 82 | +
|
| 83 | + __builtin_spirv_ConvertInt4ToInt8INTEL |
| 84 | +
|
| 85 | +SPV_INTEL_float4 Conversions |
| 86 | +----------------------------- |
| 87 | + |
| 88 | +**Translated to OpFConvert:** |
| 89 | + |
| 90 | +.. code-block:: C |
| 91 | +
|
| 92 | + __builtin_spirv_ConvertE2M1ToE4M3INTEL, __builtin_spirv_ConvertE2M1ToE5M2INTEL, |
| 93 | + __builtin_spirv_ConvertE2M1ToFP16INTEL, __builtin_spirv_ConvertE2M1ToBF16INTEL, |
| 94 | + __builtin_spirv_ConvertFP16ToE2M1INTEL, __builtin_spirv_ConvertBF16ToE2M1INTEL |
| 95 | +
|
| 96 | +SPV_INTEL_fp_conversions |
| 97 | +------------------------- |
| 98 | + |
| 99 | +This extension provides conversions with specialized rounding modes for improved precision and efficiency. |
| 100 | + |
| 101 | +**Translated to OpClampConvertFToFINTEL (clamp rounding):** |
| 102 | + |
| 103 | +.. code-block:: C |
| 104 | +
|
| 105 | + __builtin_spirv_ClampConvertFP16ToE2M1INTEL, __builtin_spirv_ClampConvertBF16ToE2M1INTEL, |
| 106 | + __builtin_spirv_ClampConvertFP16ToE4M3INTEL, __builtin_spirv_ClampConvertBF16ToE4M3INTEL, |
| 107 | + __builtin_spirv_ClampConvertFP16ToE5M2INTEL, __builtin_spirv_ClampConvertBF16ToE5M2INTEL |
| 108 | +
|
| 109 | +**Translated to OpClampConvertFToSINTEL (clamp rounding to signed integer):** |
| 110 | + |
| 111 | +.. code-block:: C |
| 112 | +
|
| 113 | + __builtin_spirv_ClampConvertFP16ToInt4INTEL, __builtin_spirv_ClampConvertBF16ToInt4INTEL |
| 114 | +
|
| 115 | +**Translated to OpStochasticRoundFToFINTEL (stochastic rounding):** |
| 116 | + |
| 117 | +.. code-block:: C |
| 118 | +
|
| 119 | + __builtin_spirv_StochasticRoundFP16ToE5M2INTEL, __builtin_spirv_StochasticRoundFP16ToE4M3INTEL, |
| 120 | + __builtin_spirv_StochasticRoundBF16ToE5M2INTEL, __builtin_spirv_StochasticRoundBF16ToE4M3INTEL, |
| 121 | + __builtin_spirv_StochasticRoundFP16ToE2M1INTEL, __builtin_spirv_StochasticRoundBF16ToE2M1INTEL |
| 122 | +
|
| 123 | +Note: These functions take an additional seed parameter (i32) and may optionally take a pointer parameter |
| 124 | +for storing the last seed value. |
| 125 | + |
| 126 | +**Translated to OpClampStochasticRoundFToSINTEL (clamp + stochastic rounding to signed integer):** |
| 127 | + |
| 128 | +.. code-block:: C |
| 129 | +
|
| 130 | + __builtin_spirv_ClampStochasticRoundFP16ToInt4INTEL, __builtin_spirv_ClampStochasticRoundBF16ToInt4INTEL |
| 131 | +
|
| 132 | +**Translated to OpClampStochasticRoundFToFINTEL (clamp + stochastic rounding):** |
| 133 | + |
| 134 | +.. code-block:: C |
| 135 | +
|
| 136 | + __builtin_spirv_ClampStochasticRoundFP16ToE5M2INTEL, __builtin_spirv_ClampStochasticRoundFP16ToE4M3INTEL, |
| 137 | + __builtin_spirv_ClampStochasticRoundBF16ToE5M2INTEL, __builtin_spirv_ClampStochasticRoundBF16ToE4M3INTEL |
| 138 | +
|
| 139 | +
|
| 140 | +Example LLVM IR to SPIR-V translation: |
| 141 | +Input LLVM IR |
| 142 | + |
| 143 | +.. code-block:: C |
| 144 | +
|
| 145 | + %alloc = alloca half |
| 146 | + %FP16_val = call half __builtin_spirv_ConvertE4M3ToFP16EXT(i8 1) |
| 147 | + store half %FP16_val, ptr %alloc |
| 148 | +
|
| 149 | +Output SPIR-V |
| 150 | + |
| 151 | +.. code-block:: C |
| 152 | +
|
| 153 | + %half_ty = OpTypeFloat 16 0 |
| 154 | + %ptr_ty = OpTypePointer %half_ty Private |
| 155 | + %int8_ty = OpTypeInt 8 0 |
| 156 | + %fp8_ty = OpTypeFloat 8 1 |
| 157 | + %const = OpConstant %int8_ty 1 |
| 158 | + /*...*/ |
| 159 | + %alloc = OpVariable %half_ty Private |
| 160 | + %fp8_val = OpBitCast %fp8_ty %const |
| 161 | + %fp16_val = OpFConvert %half_ty %fp8_val |
| 162 | + OpStore %fp16_val %alloc |
| 163 | +
|
| 164 | +Output LLVM IR |
| 165 | + |
| 166 | +.. code-block:: C |
| 167 | +
|
| 168 | + %alloc = alloca half |
| 169 | + %fp16_val = call half __builtin_spirv_ConvertE4M3ToFP16EXT(i8 1) |
| 170 | + store half %fp16_val, ptr %alloc |
| 171 | +
|
| 172 | +SPIR-V cooperative matrix instructions |
| 173 | +====================================== |
| 174 | + |
| 175 | +TBD |
0 commit comments