Skip to content

Commit 2e7ba02

Browse files
AMD-dwangstu-sdneto0
authored
Add SPV_KHR_bfloat16 support (KhronosGroup#6057)
* Add SPV_KHR_bfloat16 support * Update DEPS to include SPIRV-Headers with bfloat16 support * Fix unit test errors and format * Add validation to invalid uses of bfloat16 * Add tests * Roll back to previous commit * Fix build error * Add FPEncoding for opt::analysis::Float * Address the comments * Fix build error * format --------- Co-authored-by: Stu Smith <[email protected]> Co-authored-by: David Neto <[email protected]>
1 parent e940239 commit 2e7ba02

24 files changed

+575
-19
lines changed

Android.mk

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ SPVTOOLS_SRC_FILES := \
7676
source/val/validate_scopes.cpp \
7777
source/val/validate_small_type_uses.cpp \
7878
source/val/validate_tensor_layout.cpp \
79-
source/val/validate_type.cpp
79+
source/val/validate_type.cpp\
80+
source/val/validate_invalid_type.cpp
8081

8182
SPVTOOLS_OPT_SRC_FILES := \
8283
source/opt/aggressive_dead_code_elim_pass.cpp \

BUILD.gn

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,7 @@ static_library("spvtools_val") {
559559
"source/val/validate_small_type_uses.cpp",
560560
"source/val/validate_tensor_layout.cpp",
561561
"source/val/validate_type.cpp",
562+
"source/val/validate_invalid_type.cpp",
562563
"source/val/validation_state.cpp",
563564
"source/val/validation_state.h",
564565
]

source/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@ set(SPIRV_SOURCES
336336
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_small_type_uses.cpp
337337
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_tensor_layout.cpp
338338
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_type.cpp
339+
${CMAKE_CURRENT_SOURCE_DIR}/val/validate_invalid_type.cpp
339340
${CMAKE_CURRENT_SOURCE_DIR}/val/decoration.h
340341
${CMAKE_CURRENT_SOURCE_DIR}/val/basic_block.cpp
341342
${CMAKE_CURRENT_SOURCE_DIR}/val/construct.cpp

source/name_mapper.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,12 @@ spv_result_t FriendlyNameMapper::ParseInstruction(
211211
} break;
212212
case spv::Op::OpTypeFloat: {
213213
const auto bit_width = inst.words[2];
214-
// TODO: Handle optional fpencoding enum once actually used.
214+
if (inst.num_words > 3) {
215+
if (spv::FPEncoding(inst.words[3]) == spv::FPEncoding::BFloat16KHR) {
216+
SaveName(result_id, "bfloat16");
217+
break;
218+
}
219+
}
215220
switch (bit_width) {
216221
case 16:
217222
SaveName(result_id, "half");

source/opt/type_manager.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -792,9 +792,13 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
792792
type = new Integer(inst.GetSingleWordInOperand(0),
793793
inst.GetSingleWordInOperand(1));
794794
break;
795-
case spv::Op::OpTypeFloat:
796-
type = new Float(inst.GetSingleWordInOperand(0));
797-
break;
795+
case spv::Op::OpTypeFloat: {
796+
const spv::FPEncoding encoding =
797+
inst.NumInOperands() > 1
798+
? static_cast<spv::FPEncoding>(inst.GetSingleWordInOperand(1))
799+
: spv::FPEncoding::Max;
800+
type = new Float(inst.GetSingleWordInOperand(0), encoding);
801+
} break;
798802
case spv::Op::OpTypeVector:
799803
type = new Vector(GetType(inst.GetSingleWordInOperand(0)),
800804
inst.GetSingleWordInOperand(1));

source/opt/types.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -309,17 +309,26 @@ size_t Integer::ComputeExtraStateHash(size_t hash, SeenTypes*) const {
309309

310310
bool Float::IsSameImpl(const Type* that, IsSameCache*) const {
311311
const Float* ft = that->AsFloat();
312-
return ft && width_ == ft->width_ && HasSameDecorations(that);
312+
return ft && width_ == ft->width_ && encoding_ == ft->encoding_ &&
313+
HasSameDecorations(that);
313314
}
314315

315316
std::string Float::str() const {
316317
std::ostringstream oss;
317-
oss << "float" << width_;
318+
switch (encoding_) {
319+
case spv::FPEncoding::BFloat16KHR:
320+
assert(width_ == 16);
321+
oss << "bfloat16";
322+
break;
323+
default:
324+
oss << "float" << width_;
325+
break;
326+
}
318327
return oss.str();
319328
}
320329

321330
size_t Float::ComputeExtraStateHash(size_t hash, SeenTypes*) const {
322-
return hash_combine(hash, width_);
331+
return hash_combine(hash, width_, encoding_);
323332
}
324333

325334
Vector::Vector(const Type* type, uint32_t count)

source/opt/types.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,9 @@ class Type {
182182
// non-composite type.
183183
uint64_t NumberOfComponents() const;
184184

185-
// A bunch of methods for casting this type to a given type. Returns this if the
186-
// cast can be done, nullptr otherwise.
187-
// clang-format off
185+
// A bunch of methods for casting this type to a given type. Returns this if
186+
// the cast can be done, nullptr otherwise.
187+
// clang-format off
188188
#define DeclareCastMethod(target) \
189189
virtual target* As##target() { return nullptr; } \
190190
virtual const target* As##target() const { return nullptr; }
@@ -267,21 +267,24 @@ class Integer : public Type {
267267

268268
class Float : public Type {
269269
public:
270-
Float(uint32_t w) : Type(kFloat), width_(w) {}
270+
Float(uint32_t w, spv::FPEncoding encoding = spv::FPEncoding::Max)
271+
: Type(kFloat), width_(w), encoding_(encoding) {}
271272
Float(const Float&) = default;
272273

273274
std::string str() const override;
274275

275276
Float* AsFloat() override { return this; }
276277
const Float* AsFloat() const override { return this; }
277278
uint32_t width() const { return width_; }
279+
spv::FPEncoding encoding() const { return encoding_; }
278280

279281
size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;
280282

281283
private:
282284
bool IsSameImpl(const Type* that, IsSameCache*) const override;
283285

284-
uint32_t width_; // bit width
286+
uint32_t width_; // bit width
287+
spv::FPEncoding encoding_; // FPEncoding
285288
};
286289

287290
class Vector : public Type {

source/val/validate.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
367367
if (auto error = RayReorderNVPass(*vstate, &instruction)) return error;
368368
if (auto error = MeshShadingPass(*vstate, &instruction)) return error;
369369
if (auto error = TensorLayoutPass(*vstate, &instruction)) return error;
370+
if (auto error = InvalidTypePass(*vstate, &instruction)) return error;
370371
}
371372

372373
// Validate the preconditions involving adjacent instructions. e.g.

source/val/validate.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,9 @@ spv_result_t RayReorderNVPass(ValidationState_t& _, const Instruction* inst);
223223
/// Validates correctness of mesh shading instructions.
224224
spv_result_t MeshShadingPass(ValidationState_t& _, const Instruction* inst);
225225

226+
/// Validates correctness of certain special type instructions.
227+
spv_result_t InvalidTypePass(ValidationState_t& _, const Instruction* inst);
228+
226229
/// Calculates the reachability of basic blocks.
227230
void ReachabilityPass(ValidationState_t& _);
228231

source/val/validate_arithmetics.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,14 @@ spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
224224
<< "Expected float scalar type as Result Type: "
225225
<< spvOpcodeString(opcode);
226226

227+
if (_.IsBfloat16ScalarType(result_type)) {
228+
if (!_.HasCapability(spv::Capability::BFloat16DotProductKHR)) {
229+
return _.diag(SPV_ERROR_INVALID_DATA, inst)
230+
<< "OpDot Result Type <id> " << _.getIdName(result_type)
231+
<< "requires BFloat16DotProductKHR be declared.";
232+
}
233+
}
234+
227235
uint32_t first_vector_num_components = 0;
228236

229237
for (size_t operand_index = 2; operand_index < inst->operands().size();

0 commit comments

Comments
 (0)