Skip to content

Commit 6061b72

Browse files
authored
Add initial support for SPV_EXT_float8 (KhronosGroup#6170)
* Add initial support for SPV_EXT_float8 This also overhauls the number parsing code. It seems further work will be needed to correctly handle BFloat16. This change intends to leave the code no worse than it was. TODOs were added where relevant. Change-Id: I0e146a949ae68d0488dabbb91191f2801ba4b49c Signed-off-by: Guillaume Trebuchet <[email protected]> Signed-off-by: Kevin Petit <[email protected]> * attempt to fix MSVC warning Change-Id: I630635c54393d7ffc82001f2eec0d2538300521e * another attempt Change-Id: I237f6c8ca054347a8e5da62f7649f6fb1d76cfa1 * address review comments Change-Id: Ic0be831943f040623860ccc4dc2321c3616f10d1 * use more descriptive FP encoding enum names for FP8 Change-Id: I8ba05748480a030e0246fba14996fb123b1f146b * add a few more validator tests Change-Id: Idfa90b67d69cabcd092c5cc1dbe06f553946c73a * format fixes Change-Id: I1e90212170eaa547019e6bd3672b64ccc2ab3871 --------- Signed-off-by: Guillaume Trebuchet <[email protected]> Signed-off-by: Kevin Petit <[email protected]>
1 parent 7dda3c0 commit 6061b72

28 files changed

+1740
-63
lines changed

DEPS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ vars = {
1414

1515
're2_revision': 'c84a140c93352cdabbfb547c531be34515b12228',
1616

17-
'spirv_headers_revision': '7168a5ad041f6b6b9170f027c7417f98a2056ff0',
17+
'spirv_headers_revision': 'fd96661925488574fe247a779babe5d380b63635',
1818
}
1919

2020
deps = {

include/spirv-tools/libspirv.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,18 @@ typedef enum spv_number_kind_t {
390390
SPV_NUMBER_FLOATING,
391391
} spv_number_kind_t;
392392

393+
// Represent the encoding of floating point values
394+
typedef enum spv_fp_encoding_t {
395+
SPV_FP_ENCODING_UNKNOWN =
396+
0, // The encoding is not specified. Has to be deduced from bitwidth
397+
SPV_FP_ENCODING_IEEE754_BINARY16, // half float
398+
SPV_FP_ENCODING_IEEE754_BINARY32, // single float
399+
SPV_FP_ENCODING_IEEE754_BINARY64, // double float
400+
SPV_FP_ENCODING_BFLOAT16,
401+
SPV_FP_ENCODING_FLOAT8_E4M3,
402+
SPV_FP_ENCODING_FLOAT8_E5M2,
403+
} spv_fp_encoding_t;
404+
393405
typedef enum spv_text_to_binary_options_t {
394406
SPV_TEXT_TO_BINARY_OPTION_NONE = SPV_BIT(0),
395407
// Numeric IDs in the binary will have the same values as in the source.
@@ -445,6 +457,8 @@ typedef struct spv_parsed_operand_t {
445457
spv_number_kind_t number_kind;
446458
// The number of bits for a literal number type.
447459
uint32_t number_bit_width;
460+
// The encoding used for floating point values
461+
spv_fp_encoding_t fp_encoding;
448462
} spv_parsed_operand_t;
449463

450464
// An instruction parsed from a binary SPIR-V module.

source/binary.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ class Parser {
190190
struct NumberType {
191191
spv_number_kind_t type;
192192
uint32_t bit_width;
193+
spv_fp_encoding_t encoding;
193194
};
194195

195196
// The state used to parse a single SPIR-V binary module.
@@ -385,8 +386,6 @@ spv_result_t Parser::parseInstruction() {
385386
assert(_.requires_endian_conversion ||
386387
(_.endian_converted_words.size() == 1));
387388

388-
recordNumberType(inst_offset, &inst);
389-
390389
if (_.requires_endian_conversion) {
391390
// We must wait until here to set this pointer, because the vector might
392391
// have been be resized while we accumulated its elements.
@@ -398,6 +397,8 @@ spv_result_t Parser::parseInstruction() {
398397
}
399398
inst.num_words = inst_word_count;
400399

400+
recordNumberType(inst_offset, &inst);
401+
401402
// We must wait until here to set this pointer, because the vector might
402403
// have been be resized while we accumulated its elements.
403404
inst.operands = _.operands.data();
@@ -833,6 +834,7 @@ spv_result_t Parser::setNumericTypeInfoForType(
833834

834835
parsed_operand->number_kind = info.type;
835836
parsed_operand->number_bit_width = info.bit_width;
837+
parsed_operand->fp_encoding = info.encoding;
836838
// Round up the word count.
837839
parsed_operand->num_words = static_cast<uint16_t>((info.bit_width + 31) / 32);
838840
return SPV_SUCCESS;
@@ -850,6 +852,17 @@ void Parser::recordNumberType(size_t inst_offset,
850852
} else if (spv::Op::OpTypeFloat == opcode) {
851853
info.type = SPV_NUMBER_FLOATING;
852854
info.bit_width = peekAt(inst_offset + 2);
855+
if (inst->num_words >= 4) {
856+
const spvtools::OperandDesc* desc;
857+
spv_result_t status = spvtools::LookupOperand(
858+
SPV_OPERAND_TYPE_FPENCODING, peekAt(inst_offset + 3), &desc);
859+
if (status == SPV_SUCCESS) {
860+
info.encoding = spvFPEncodingFromOperandFPEncoding(
861+
static_cast<spv::FPEncoding>(desc->value));
862+
} else {
863+
info.encoding = SPV_FP_ENCODING_UNKNOWN;
864+
}
865+
}
853866
}
854867
// The *result* Id of a type generating instruction is the type Id.
855868
_.type_id_to_number_type_info[inst->result_id] = info;

source/name_mapper.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,14 @@ spv_result_t FriendlyNameMapper::ParseInstruction(
217217
SaveName(result_id, "bfloat16");
218218
break;
219219
}
220+
if (spv::FPEncoding(inst.words[3]) == spv::FPEncoding::Float8E4M3EXT) {
221+
SaveName(result_id, "fp8e4m3");
222+
break;
223+
}
224+
if (spv::FPEncoding(inst.words[3]) == spv::FPEncoding::Float8E5M2EXT) {
225+
SaveName(result_id, "fp8e5m2");
226+
break;
227+
}
220228
}
221229
switch (bit_width) {
222230
case 16:

source/operand.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,3 +619,17 @@ std::function<bool(unsigned)> spvDbgInfoExtOperandCanBeForwardDeclaredFunction(
619619
}
620620
return out;
621621
}
622+
623+
spv_fp_encoding_t spvFPEncodingFromOperandFPEncoding(spv::FPEncoding encoding) {
624+
switch (encoding) {
625+
case spv::FPEncoding::BFloat16KHR:
626+
return SPV_FP_ENCODING_BFLOAT16;
627+
case spv::FPEncoding::Float8E4M3EXT:
628+
return SPV_FP_ENCODING_FLOAT8_E4M3;
629+
case spv::FPEncoding::Float8E5M2EXT:
630+
return SPV_FP_ENCODING_FLOAT8_E5M2;
631+
case spv::FPEncoding::Max:
632+
break;
633+
}
634+
return SPV_FP_ENCODING_UNKNOWN;
635+
}

source/operand.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,4 +122,7 @@ std::function<bool(unsigned)> spvOperandCanBeForwardDeclaredFunction(
122122
std::function<bool(unsigned)> spvDbgInfoExtOperandCanBeForwardDeclaredFunction(
123123
spv::Op opcode, spv_ext_inst_type_t ext_type, uint32_t key);
124124

125+
// Converts an spv::FPEncoding to spv_fp_encoding_t
126+
spv_fp_encoding_t spvFPEncodingFromOperandFPEncoding(spv::FPEncoding encoding);
127+
125128
#endif // SOURCE_OPERAND_H_

source/opt/types.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,14 @@ std::string Float::str() const {
320320
assert(width_ == 16);
321321
oss << "bfloat16";
322322
break;
323+
case spv::FPEncoding::Float8E4M3EXT:
324+
assert(width_ == 8);
325+
oss << "fp8e4m3";
326+
break;
327+
case spv::FPEncoding::Float8E5M2EXT:
328+
assert(width_ == 8);
329+
oss << "fp8e5m2";
330+
break;
323331
default:
324332
oss << "float" << width_;
325333
break;

source/parsed_operand.cpp

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,35 @@ void EmitNumericLiteral(std::ostream* out, const spv_parsed_instruction_t& inst,
4343
*out << word;
4444
break;
4545
case SPV_NUMBER_FLOATING:
46-
if (operand.number_bit_width == 16) {
47-
*out << spvtools::utils::FloatProxy<spvtools::utils::Float16>(
48-
uint16_t(word & 0xFFFF));
49-
} else {
50-
// Assume 32-bit floats.
51-
*out << spvtools::utils::FloatProxy<float>(word);
46+
switch (operand.fp_encoding) {
47+
case SPV_FP_ENCODING_IEEE754_BINARY16:
48+
*out << spvtools::utils::FloatProxy<spvtools::utils::Float16>(
49+
uint16_t(word & 0xFFFF));
50+
break;
51+
case SPV_FP_ENCODING_IEEE754_BINARY32:
52+
*out << spvtools::utils::FloatProxy<float>(word);
53+
break;
54+
case SPV_FP_ENCODING_FLOAT8_E4M3:
55+
*out << spvtools::utils::FloatProxy<spvtools::utils::Float8_E4M3>(
56+
uint8_t(word & 0xFF));
57+
break;
58+
case SPV_FP_ENCODING_FLOAT8_E5M2:
59+
*out << spvtools::utils::FloatProxy<spvtools::utils::Float8_E5M2>(
60+
uint8_t(word & 0xFF));
61+
break;
62+
// TODO Bfloat16
63+
case SPV_FP_ENCODING_UNKNOWN:
64+
switch (operand.number_bit_width) {
65+
case 16:
66+
*out << spvtools::utils::FloatProxy<spvtools::utils::Float16>(
67+
uint16_t(word & 0xFFFF));
68+
break;
69+
case 32:
70+
*out << spvtools::utils::FloatProxy<float>(word);
71+
break;
72+
}
73+
default:
74+
break;
5275
}
5376
break;
5477
default:

source/text_handler.cpp

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -254,13 +254,13 @@ spv_result_t AssemblyContext::binaryEncodeNumericLiteral(
254254
<< "Unexpected numeric literal type";
255255
case IdTypeClass::kScalarIntegerType:
256256
if (type.isSigned) {
257-
number_type = {type.bitwidth, SPV_NUMBER_SIGNED_INT};
257+
number_type = {type.bitwidth, SPV_NUMBER_SIGNED_INT, type.encoding};
258258
} else {
259-
number_type = {type.bitwidth, SPV_NUMBER_UNSIGNED_INT};
259+
number_type = {type.bitwidth, SPV_NUMBER_UNSIGNED_INT, type.encoding};
260260
}
261261
break;
262262
case IdTypeClass::kScalarFloatType:
263-
number_type = {type.bitwidth, SPV_NUMBER_FLOATING};
263+
number_type = {type.bitwidth, SPV_NUMBER_FLOATING, type.encoding};
264264
break;
265265
case IdTypeClass::kBottom:
266266
// kBottom means the type is unknown and we need to infer the type before
@@ -270,11 +270,11 @@ spv_result_t AssemblyContext::binaryEncodeNumericLiteral(
270270
// signed integer, otherwise an unsigned integer.
271271
uint32_t bitwidth = static_cast<uint32_t>(assumedBitWidth(type));
272272
if (strchr(val, '.')) {
273-
number_type = {bitwidth, SPV_NUMBER_FLOATING};
273+
number_type = {bitwidth, SPV_NUMBER_FLOATING, type.encoding};
274274
} else if (type.isSigned || val[0] == '-') {
275-
number_type = {bitwidth, SPV_NUMBER_SIGNED_INT};
275+
number_type = {bitwidth, SPV_NUMBER_SIGNED_INT, type.encoding};
276276
} else {
277-
number_type = {bitwidth, SPV_NUMBER_UNSIGNED_INT};
277+
number_type = {bitwidth, SPV_NUMBER_UNSIGNED_INT, type.encoding};
278278
}
279279
break;
280280
}
@@ -330,14 +330,27 @@ spv_result_t AssemblyContext::recordTypeDefinition(
330330
if (pInst->words.size() != 4)
331331
return diagnostic() << "Invalid OpTypeInt instruction";
332332
types_[value] = {pInst->words[2], pInst->words[3] != 0,
333-
IdTypeClass::kScalarIntegerType};
333+
IdTypeClass::kScalarIntegerType, SPV_FP_ENCODING_UNKNOWN};
334334
} else if (pInst->opcode == spv::Op::OpTypeFloat) {
335335
if ((pInst->words.size() != 3) && (pInst->words.size() != 4))
336336
return diagnostic() << "Invalid OpTypeFloat instruction";
337-
// TODO(kpet) Do we need to record the FP Encoding here?
338-
types_[value] = {pInst->words[2], false, IdTypeClass::kScalarFloatType};
337+
spv_fp_encoding_t enc = SPV_FP_ENCODING_UNKNOWN;
338+
if (pInst->words.size() >= 4) {
339+
const spvtools::OperandDesc* desc;
340+
spv_result_t status = spvtools::LookupOperand(SPV_OPERAND_TYPE_FPENCODING,
341+
pInst->words[3], &desc);
342+
if (status == SPV_SUCCESS) {
343+
enc = spvFPEncodingFromOperandFPEncoding(
344+
static_cast<spv::FPEncoding>(desc->value));
345+
} else {
346+
return diagnostic() << "Invalid OpTypeFloat encoding";
347+
}
348+
}
349+
types_[value] = {pInst->words[2], false, IdTypeClass::kScalarFloatType,
350+
enc};
339351
} else {
340-
types_[value] = {0, false, IdTypeClass::kOtherType};
352+
types_[value] = {0, false, IdTypeClass::kOtherType,
353+
SPV_FP_ENCODING_UNKNOWN};
341354
}
342355
return SPV_SUCCESS;
343356
}

source/text_handler.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ struct IdType {
4747
uint32_t bitwidth; // Safe to assume that we will not have > 2^32 bits.
4848
bool isSigned; // This is only significant if type_class is integral.
4949
IdTypeClass type_class;
50+
spv_fp_encoding_t encoding;
5051
};
5152

5253
// Default equality operator for IdType. Tests if all members are the same.

0 commit comments

Comments
 (0)