diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 8fd533db83d9a..19ef961d6c0e7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -344,6 +344,7 @@ def SPV_KHR_subgroup_rotate : I32EnumAttrCase<"SPV_KHR_subgroup
 def SPV_KHR_non_semantic_info : I32EnumAttrCase<"SPV_KHR_non_semantic_info", 29>;
 def SPV_KHR_terminate_invocation : I32EnumAttrCase<"SPV_KHR_terminate_invocation", 30>;
 def SPV_KHR_cooperative_matrix : I32EnumAttrCase<"SPV_KHR_cooperative_matrix", 31>;
+def SPV_KHR_bfloat16 : I32EnumAttrCase<"SPV_KHR_bfloat16", 32>;
 
 def SPV_EXT_demote_to_helper_invocation : I32EnumAttrCase<"SPV_EXT_demote_to_helper_invocation", 1000>;
 def SPV_EXT_descriptor_indexing : I32EnumAttrCase<"SPV_EXT_descriptor_indexing", 1001>;
@@ -436,7 +437,7 @@ def SPIRV_ExtensionAttr :
     SPV_KHR_fragment_shader_barycentric, SPV_KHR_ray_cull_mask,
     SPV_KHR_uniform_group_instructions, SPV_KHR_subgroup_rotate,
     SPV_KHR_non_semantic_info, SPV_KHR_terminate_invocation,
-    SPV_KHR_cooperative_matrix,
+    SPV_KHR_cooperative_matrix, SPV_KHR_bfloat16,
     SPV_EXT_demote_to_helper_invocation, SPV_EXT_descriptor_indexing,
     SPV_EXT_fragment_fully_covered, SPV_EXT_fragment_invocation_density,
     SPV_EXT_fragment_shader_interlock, SPV_EXT_physical_storage_buffer,
@@ -1412,6 +1413,23 @@ def SPIRV_C_ShaderStereoViewNV : I32EnumAttrCase<"Shade
     Extension<[SPV_NV_stereo_view_rendering]>
   ];
 }
+def SPIRV_C_BFloat16TypeKHR : I32EnumAttrCase<"BFloat16TypeKHR", 5116> {
+  list<Availability> availability = [
+    Extension<[SPV_KHR_bfloat16]>
+  ];
+}
+def SPIRV_C_BFloat16DotProductKHR : I32EnumAttrCase<"BFloat16DotProductKHR", 5117> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR];
+  list<Availability> availability = [
+    Extension<[SPV_KHR_bfloat16]>
+  ];
+}
+def SPIRV_C_BFloat16CooperativeMatrixKHR : I32EnumAttrCase<"BFloat16CooperativeMatrixKHR", 5118> {
+  list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR, SPIRV_C_CooperativeMatrixKHR];
+  list<Availability> availability = [
+    Extension<[SPV_KHR_bfloat16]>
+  ];
+}
 
 def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"Bfloat16ConversionINTEL", 6115> {
   list<Availability> availability = [
@@ -1518,7 +1536,8 @@ def SPIRV_CapabilityAttr :
     SPIRV_C_StorageTexelBufferArrayNonUniformIndexing,
     SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
     SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
-    SPIRV_C_CacheControlsINTEL
+    SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
+    SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR
   ]>;
 
 def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
@@ -3217,6 +3236,16 @@ def SPIRV_ExecutionModelAttr :
     SPIRV_EM_TaskEXT, SPIRV_EM_MeshEXT
   ]>;
 
+def SPIRV_FPE_BFloat16KHR : I32EnumAttrCase<"BFloat16KHR", 0> {
+  list<Availability> availability = [
+    Capability<[SPIRV_C_BFloat16TypeKHR]>
+  ];
+}
+def SPIRV_FPEncodingAttr :
+    SPIRV_I32EnumAttr<"FPEncoding", "valid SPIR-V FPEncoding", "f_p_encoding", [
+      SPIRV_FPE_BFloat16KHR
+    ]>;
+
 def SPIRV_FC_None : I32BitEnumAttrCaseNone<"None">;
 def SPIRV_FC_Inline : I32BitEnumAttrCaseBit<"Inline", 0>;
 def SPIRV_FC_DontInline : I32BitEnumAttrCaseBit<"DontInline", 1>;
@@ -4161,10 +4190,12 @@ def SPIRV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
 def SPIRV_Int16 : TypeAlias<I16, "Int16">;
 def SPIRV_Int32 : TypeAlias<I32, "Int32">;
 def SPIRV_Float32 : TypeAlias<F32, "Float32">;
+def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
 def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
 def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
+def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_Float, SPIRV_BFloat16KHR]>;
 def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
-                                         [SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
+                                         [SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat]>;
 // Component type check is done in the type parser for the following SPIR-V
 // dialect-specific types so we use "Any" here.
 def SPIRV_AnyPtr : DialectType<SPIRV_Dialect, SPIRV_IsPtrType, "any SPIR-V pointer type">;
-def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_Float]>;
+def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
 def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
 def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
 def SPIRV_Composite : AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
                                  SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix]>;
 def SPIRV_Type : AnyTypeOf<[
-    SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
+    SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
     SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
     SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
   ]>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
index b05ee0251df5b..a5c8aa8fb450c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
@@ -86,7 +86,7 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {
 
 // -----
 
-def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float, []> {
+def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_AnyFloat, []> {
   let summary = [{
     Convert value numerically from floating point to signed integer, with
     round toward 0.0.
@@ -111,7 +111,7 @@ def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float
 
 // -----
 
-def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float, []> {
+def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_AnyFloat, []> {
   let summary = [{
     Convert value numerically from floating point to unsigned integer, with
     round toward 0.0.
@@ -138,7 +138,7 @@ def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float
 // -----
 
 def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
-                                       SPIRV_Float,
+                                       SPIRV_AnyFloat,
                                        SPIRV_Integer,
                                        [SignedOp]> {
   let summary = [{
@@ -165,7 +165,7 @@ def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
 // -----
 
 def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
-                                       SPIRV_Float,
+                                       SPIRV_AnyFloat,
                                        SPIRV_Integer,
                                        [UnsignedOp]> {
   let summary = [{
@@ -192,8 +192,8 @@ def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
 // -----
 
 def SPIRV_FConvertOp : SPIRV_CastOp<"FConvert",
-                                    SPIRV_Float,
-                                    SPIRV_Float,
+                                    SPIRV_AnyFloat,
+                                    SPIRV_AnyFloat,
                                     [UsableInSpecConstantOp]> {
   let summary = [{
     Convert value numerically from one floating-point width to another
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 0cf5f0823be63..a21acef1c4b43 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -175,10 +175,7 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
 
   // Check other allowed types
   if (auto t = llvm::dyn_cast<FloatType>(type)) {
-    if (type.isBF16()) {
-      parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
-      return Type();
-    }
+    // TODO: All float types are allowed for now, but this should be fixed.
   } else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
     if (!ScalarType::isValid(t)) {
       parser.emitError(typeLoc,
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 337df3a5a65f0..1e71f4277f660 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -505,7 +505,7 @@ bool ScalarType::classof(Type type) {
 }
 
 bool ScalarType::isValid(FloatType type) {
-  return llvm::is_contained({16u, 32u, 64u}, type.getWidth()) && !type.isBF16();
+  return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
 }
 
 bool ScalarType::isValid(IntegerType type) {
@@ -514,6 +514,11 @@ bool ScalarType::isValid(IntegerType type) {
 
 void ScalarType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                                std::optional<StorageClass> storage) {
+  if (isa<BFloat16Type>(*this)) {
+    static const Extension ext = Extension::SPV_KHR_bfloat16;
+    extensions.push_back(ext);
+  }
+
   // 8- or 16-bit integer/floating-point numbers will require extra extensions
   // to appear in interface storage classes. See SPV_KHR_16bit_storage and
   // SPV_KHR_8bit_storage for more details.
@@ -619,7 +624,16 @@ void ScalarType::getCapabilities(
   } else {
     assert(llvm::isa<FloatType>(*this));
     switch (bitwidth) {
-      WIDTH_CASE(Float, 16);
+    case 16: {
+      if (isa<BFloat16Type>(*this)) {
+        static const Capability cap = Capability::BFloat16TypeKHR;
+        capabilities.push_back(cap);
+      } else {
+        static const Capability cap = Capability::Float16;
+        capabilities.push_back(cap);
+      }
+      break;
+    }
       WIDTH_CASE(Float, 64);
     case 32:
       break;
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 3957dbc0db984..cfd3e19bca45f 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -867,11 +867,15 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
     typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
   } break;
   case spirv::Opcode::OpTypeFloat: {
-    if (operands.size() != 2)
-      return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
+    if (operands.size() != 2 && operands.size() != 3)
+      return emitError(unknownLoc,
+                       "OpTypeFloat expects either 2 operands (type, bitwidth) "
+                       "or 3 operands (type, bitwidth, encoding), but got ")
+             << operands.size();
+
+    uint32_t bitWidth = operands[1];
     Type floatTy;
-    switch (operands[1]) {
+    switch (bitWidth) {
     case 16:
       floatTy = opBuilder.getF16Type();
       break;
@@ -883,8 +887,20 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
       break;
     default:
       return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
-             << operands[1];
+             << bitWidth;
+    }
+
+    if (operands.size() == 3) {
+      if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR)
+        return emitError(unknownLoc, "unsupported OpTypeFloat FP encoding: ")
+               << operands[2];
+      if (bitWidth != 16)
+        return emitError(unknownLoc,
+                         "invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
+               << bitWidth << " (expected 16)";
+      floatTy = opBuilder.getBF16Type();
     }
+
     typeMap[operands[0]] = floatTy;
   } break;
   case spirv::Opcode::OpTypeVector: {
@@ -1399,6 +1415,9 @@ LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
     } else if (floatType.isF16()) {
       APInt data(16, operands[2]);
       value = APFloat(APFloat::IEEEhalf(), data);
+    } else if (floatType.isBF16()) {
+      APInt data(16, operands[2]);
+      value = APFloat(APFloat::BFloat(), data);
     }
 
     auto attr = opBuilder.getFloatAttr(floatType, value);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 15e06616f4492..5de498cb454a7 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -523,6 +523,9 @@ LogicalResult Serializer::prepareBasicType(
   if (auto floatType = dyn_cast<FloatType>(type)) {
     typeEnum = spirv::Opcode::OpTypeFloat;
     operands.push_back(floatType.getWidth());
+    if (floatType.isBF16()) {
+      operands.push_back(static_cast<uint32_t>(spirv::FPEncoding::BFloat16KHR));
+    }
     return success();
   }
 
@@ -996,21 +999,23 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
   auto resultID = getNextID();
   APFloat value = floatAttr.getValue();
+  const llvm::fltSemantics *semantics = &value.getSemantics();
 
   auto opcode =
       isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
 
-  if (&value.getSemantics() == &APFloat::IEEEsingle()) {
+  if (semantics == &APFloat::IEEEsingle()) {
     uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
     encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
-  } else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
+  } else if (semantics == &APFloat::IEEEdouble()) {
     struct DoubleWord {
       uint32_t word1;
       uint32_t word2;
    } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
     encodeInstructionInto(typesGlobalValues, opcode,
                           {typeID, resultID, words.word1, words.word2});
-  } else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
+  } else if (semantics == &APFloat::IEEEhalf() ||
+             semantics == &APFloat::BFloat()) {
     uint32_t word =
         static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
     encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
diff --git a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
index 82d750755ffe2..1737f4a906bf8 100644
--- a/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
+++ b/mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
@@ -173,6 +173,12 @@ func.func @float16(%arg0: f16) { return }
 // NOEMU-SAME: f64
 func.func @float64(%arg0: f64) { return }
 
+// CHECK-LABEL: spirv.func @bfloat16
+// CHECK-SAME: f32
+// NOEMU-LABEL: func.func @bfloat16
+// NOEMU-SAME: bf16
+func.func @bfloat16(%arg0: bf16) { return }
+
 // f80 is not supported by SPIR-V.
 // CHECK-LABEL: func.func @float80
 // CHECK-SAME: f80
@@ -206,18 +212,6 @@ func.func @float64(%arg0: f64) { return }
 
 // -----
 
-// Check that bf16 is not supported.
-module attributes {
-  spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>
-} {
-
-// CHECK-NOT: spirv.func @bf16_type
-func.func @bf16_type(%arg0: bf16) { return }
-
-} // end module
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // Complex types
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
index 2d0c86e08de5a..d58c27598f2b8 100644
--- a/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
@@ -12,6 +12,14 @@ func.func @fadd_scalar(%arg: f32) -> f32 {
 
 // -----
 
+func.func @fadd_bf16_scalar(%arg: bf16) -> bf16 {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.FAdd %arg, %arg : bf16
+  return %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.FDiv
 //===----------------------------------------------------------------------===//
@@ -24,6 +32,14 @@ func.func @fdiv_scalar(%arg: f32) -> f32 {
 
 // -----
 
+func.func @fdiv_bf16_scalar(%arg: bf16) -> bf16 {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.FDiv %arg, %arg : bf16
+  return %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.FMod
 //===----------------------------------------------------------------------===//
@@ -36,6 +52,14 @@ func.func @fmod_scalar(%arg: f32) -> f32 {
 
 // -----
 
+func.func @fmod_bf16_scalar(%arg: bf16) -> bf16 {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.FMod %arg, %arg : bf16
+  return %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.FMul
 //===----------------------------------------------------------------------===//
@@ -70,6 +94,14 @@ func.func @fmul_bf16(%arg: bf16) -> bf16 {
 // -----
 
+func.func @fmul_bf16_vector(%arg: vector<4xbf16>) -> vector<4xbf16> {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.FMul %arg, %arg : vector<4xbf16>
+  return %0 : vector<4xbf16>
+}
+
+// -----
+
 func.func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
   // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
   %0 = spirv.FMul %arg, %arg : tensor<4xf32>
   return %0 : tensor<4xf32>
@@ -90,6 +122,14 @@ func.func @fnegate_scalar(%arg: f32) -> f32 {
 
 // -----
 
+func.func @fnegate_bf16_scalar(%arg: bf16) -> bf16 {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.FNegate %arg : bf16
+  return %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.FRem
 //===----------------------------------------------------------------------===//
@@ -102,6 +142,14 @@ func.func @frem_scalar(%arg: f32) -> f32 {
 
 // -----
 
+func.func @frem_bf16_scalar(%arg: bf16) -> bf16 {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.FRem %arg, %arg : bf16
+  return %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.FSub
 //===----------------------------------------------------------------------===//
@@ -114,6 +162,14 @@ func.func @fsub_scalar(%arg: f32) -> f32 {
 
 // -----
 
+func.func @fsub_bf16_scalar(%arg: bf16) -> bf16 {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %0 = spirv.FSub %arg, %arg : bf16
+  return %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.IAdd
 //===----------------------------------------------------------------------===//
@@ -489,3 +545,11 @@ func.func @vector_times_scalar(%vector: vector<4xf32>, %scalar: f32) -> vector<3
   %0 = spirv.VectorTimesScalar %vector, %scalar : (vector<4xf32>, f32) -> vector<3xf32>
   return %0 : vector<3xf32>
 }
+
+// -----
+
+func.func @vector_bf16_times_scalar_bf16(%vector: vector<4xbf16>, %scalar: bf16) -> vector<4xbf16> {
+  // expected-error @+1 {{op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4}}
+  %0 = spirv.VectorTimesScalar %vector, %scalar : (vector<4xbf16>, bf16) -> vector<4xbf16>
+  return %0 : vector<4xbf16>
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir b/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
index cc0abd3a42dcb..661497d5fff38 100644
--- a/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir
@@ -272,3 +272,11 @@ func.func @atomic_fadd(%ptr : !spirv.ptr, %value : f32) -> f
   %0 = spirv.EXT.AtomicFAdd <Device> <None> %ptr, %value : !spirv.ptr<f32, StorageBuffer>
   return %0 : f32
 }
+
+// -----
+
+func.func @atomic_bf16_fadd(%ptr : !spirv.ptr<bf16, StorageBuffer>, %value : bf16) -> bf16 {
+  // expected-error @+1 {{op operand #1 must be 16/32/64-bit float, but got 'bf16'}}
+  %0 = spirv.EXT.AtomicFAdd <Device> <None> %ptr, %value : !spirv.ptr<bf16, StorageBuffer>
+  return %0 : bf16
+}
diff --git a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
index 34d0109e6bb44..4480a1f3720f2 100644
--- a/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
@@ -110,6 +110,14 @@ func.func @convert_f_to_s_vector(%arg0 : vector<3xf32>) -> vector<3xi32> {
 
 // -----
 
+func.func @convert_bf16_to_s32_scalar(%arg0 : bf16) -> i32 {
+  // CHECK: {{%.*}} = spirv.ConvertFToS {{%.*}} : bf16 to i32
+  %0 = spirv.ConvertFToS %arg0 : bf16 to i32
+  spirv.ReturnValue %0 : i32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.ConvertFToU
 //===----------------------------------------------------------------------===//
@@ -146,6 +154,14 @@ func.func @convert_f_to_u.coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgrou
 
 // -----
 
+func.func @convert_bf16_to_u32_scalar(%arg0 : bf16) -> i32 {
+  // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : bf16 to i32
+  %0 = spirv.ConvertFToU %arg0 : bf16 to i32
+  spirv.ReturnValue %0 : i32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.ConvertSToF
 //===----------------------------------------------------------------------===//
@@ -174,6 +190,14 @@ func.func @convert_s_to_f_vector(%arg0 : vector<3xi32>) -> vector<3xf32> {
 
 // -----
 
+func.func @convert_s32_to_bf16_scalar(%arg0 : i32) -> bf16 {
+  // CHECK: {{%.*}} = spirv.ConvertSToF {{%.*}} : i32 to bf16
+  %0 = spirv.ConvertSToF %arg0 : i32 to bf16
+  spirv.ReturnValue %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.ConvertUToF
 //===----------------------------------------------------------------------===//
@@ -202,6 +226,14 @@ func.func @convert_u_to_f_vector(%arg0 : vector<3xi32>) -> vector<3xf32> {
 
 // -----
 
+func.func @convert_u32_to_bf16_scalar(%arg0 : i32) -> bf16 {
+  // CHECK: {{%.*}} = spirv.ConvertUToF {{%.*}} : i32 to bf16
+  %0 = spirv.ConvertUToF %arg0 : i32 to bf16
+  spirv.ReturnValue %0 : bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.FConvert
 //===----------------------------------------------------------------------===//
@@ -238,6 +270,30 @@ func.func @f_convert_vector(%arg0 : f32) -> f32 {
 
 // -----
 
+func.func @f_convert_bf16_to_f32_scalar(%arg0 : bf16) -> f32 {
+  // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : bf16 to f32
+  %0 = spirv.FConvert %arg0 : bf16 to f32
+  spirv.ReturnValue %0 : f32
+}
+
+// -----
+
+func.func @f_convert_f32_to_bf16_vector(%arg0 : vector<3xf32>) -> vector<3xbf16> {
+  // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : vector<3xf32> to vector<3xbf16>
+  %0 = spirv.FConvert %arg0 : vector<3xf32> to vector<3xbf16>
+  spirv.ReturnValue %0 : vector<3xbf16>
+}
+
+// -----
+
+func.func @f_convert_f32_to_bf16_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>) -> !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA> {
+  // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> to !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
+  %0 = spirv.FConvert %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> to !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
+  spirv.ReturnValue %0 : !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.SConvert
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index 3fc8dfb2767d1..e71b545de11df 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -11,6 +11,13 @@ func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> ve
   return %0: vector<3xf32>
 }
 
+// CHECK-LABEL: func @composite_construct_bf16_vector
+func.func @composite_construct_bf16_vector(%arg0: bf16, %arg1: bf16, %arg2 : bf16) -> vector<3xbf16> {
+  // CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (bf16, bf16, bf16) -> vector<3xbf16>
+  %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (bf16, bf16, bf16) -> vector<3xbf16>
+  return %0: vector<3xbf16>
+}
+
 // CHECK-LABEL: func @composite_construct_struct
 func.func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spirv.array<4xf32>, %arg2 : !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)> {
   // CHECK: spirv.CompositeConstruct
diff --git a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
index 0be047932c1f3..4c141a285cd30 100644
--- a/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/gl-ops.mlir
@@ -50,6 +50,14 @@ func.func @exp(%arg0 : i32) -> () {
 
 // -----
 
+func.func @exp_bf16(%arg0 : bf16) -> () {
+  // expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values of length 2/3/4}}
+  %2 = spirv.GL.Exp %arg0 : bf16
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.GL.{F|S|U}{Max|Min}
 //===----------------------------------------------------------------------===//
@@ -92,6 +100,15 @@ func.func @iminmax(%arg0: i32, %arg1: i32) {
 
 // -----
 
+func.func @fmaxminbf16vec(%arg0 : vector<3xbf16>, %arg1 : vector<3xbf16>) {
+  // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
+  %1 = spirv.GL.FMax %arg0, %arg1 : vector<3xbf16>
+  %2 = spirv.GL.FMin %arg0, %arg1 : vector<3xbf16>
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.GL.InverseSqrt
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
index 5c24f0e6a7d33..d6c34645f5746 100644
--- a/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/logical-ops.mlir
@@ -201,6 +201,14 @@ func.func @select_op_float(%arg0: i1) -> () {
   return
 }
 
+func.func @select_op_bfloat16(%arg0: i1) -> () {
+  %0 = spirv.Constant 2.0 : bf16
+  %1 = spirv.Constant 3.0 : bf16
+  // CHECK: spirv.Select {{%.*}}, {{%.*}}, {{%.*}} : i1, bf16
+  %2 = spirv.Select %arg0, %0, %1 : i1, bf16
+  return
+}
+
 func.func @select_op_ptr(%arg0: i1) -> () {
   %0 = spirv.Variable : !spirv.ptr<f32, Function>
   %1 = spirv.Variable : !spirv.ptr<f32, Function>
diff --git a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
index 5f56de6ad1fa9..7ab94f17360d5 100644
--- a/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir
@@ -184,6 +184,14 @@ func.func @group_non_uniform_fmul_clustered_reduce(%val: vector<2xf32>) -> vecto
 
 // -----
 
+func.func @group_non_uniform_bf16_fmul_reduce(%val: bf16) -> bf16 {
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
+  %0 = spirv.GroupNonUniformFMul <Workgroup> <Reduce> %val : bf16 -> bf16
+  return %0: bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.GroupNonUniformFMax
 //===----------------------------------------------------------------------===//
@@ -197,6 +205,14 @@ func.func @group_non_uniform_fmax_reduce(%val: f32) -> f32 {
 
 // -----
 
+func.func @group_non_uniform_bf16_fmax_reduce(%val: bf16) -> bf16 {
+  // expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
+  %0 = spirv.GroupNonUniformFMax <Workgroup> <Reduce> %val : bf16 -> bf16
+  return %0: bf16
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.GroupNonUniformFMin
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/types.mlir b/mlir/test/Dialect/SPIRV/IR/types.mlir
index b63a08d96e6af..c23894c62826b 100644
--- a/mlir/test/Dialect/SPIRV/IR/types.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/types.mlir
@@ -15,6 +15,9 @@ func.func private @vector_array_type(!spirv.array< 32 x vector<4xf32> >) -> ()
 // CHECK: func private @array_type_stride(!spirv.array<4 x !spirv.array<4 x f32, stride=4>, stride=128>)
 func.func private @array_type_stride(!spirv.array< 4 x !spirv.array<4 x f32, stride=4>, stride = 128>) -> ()
 
+// CHECK: func private @vector_array_type_bf16(!spirv.array<32 x vector<4xbf16>>)
+func.func private @vector_array_type_bf16(!spirv.array<32 x vector<4xbf16> >) -> ()
+
 // -----
 
 // expected-error @+1 {{expected '<'}}
@@ -57,11 +60,6 @@ func.func private @tensor_type(!spirv.array<4xtensor<4xf32>>) -> ()
 
 // -----
 
-// expected-error @+1 {{cannot use 'bf16' to compose SPIR-V types}}
-func.func private @bf16_type(!spirv.array<4xbf16>) -> ()
-
-// -----
-
 // expected-error @+1 {{only 1/8/16/32/64-bit integer type allowed but found 'i256'}}
 func.func private @i256_type(!spirv.array<4xi256>) -> ()
 
diff --git a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
index ff5ac7cea8fc6..2b237665ffc4a 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/vce-deduction.mlir
@@ -217,3 +217,17 @@ spirv.module Logical GLSL450 attributes {
   spirv.GlobalVariable @data : !spirv.ptr, Uniform>
   spirv.GlobalVariable @img : !spirv.ptr, UniformConstant>
 }
+
+// Using bfloat16 requires BFloat16TypeKHR capability and SPV_KHR_bfloat16 extension.
+// CHECK: requires #spirv.vce
+spirv.module Logical GLSL450 attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce,
+    #spirv.resource_limits<>
+  >
+} {
+  spirv.func @load_bf16(%ptr : !spirv.ptr<bf16, StorageBuffer>) -> bf16 "None" {
+    %val = spirv.Load "StorageBuffer" %ptr : bf16
+    spirv.ReturnValue %val : bf16
+  }
+}
diff --git a/mlir/test/Target/SPIRV/cast-ops.mlir b/mlir/test/Target/SPIRV/cast-ops.mlir
index ede0bf30511ef..04a468b39b645 100644
--- a/mlir/test/Target/SPIRV/cast-ops.mlir
+++ b/mlir/test/Target/SPIRV/cast-ops.mlir
@@ -25,6 +25,11 @@ spirv.module Logical GLSL450 requires #spirv.vce {
     %0 = spirv.ConvertFToS %arg0 : f64 to i32
     spirv.ReturnValue %0 : i32
   }
+  spirv.func @convert_bf16_to_s32(%arg0 : bf16) -> i32 "None" {
+    // CHECK: {{%.*}} = spirv.ConvertFToS {{%.*}} : bf16 to i32
+    %0 = spirv.ConvertFToS %arg0 : bf16 to i32
+    spirv.ReturnValue %0 : i32
+  }
   spirv.func @convert_f_to_u(%arg0 : f32) -> i32 "None" {
     // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : f32 to i32
     %0 = spirv.ConvertFToU %arg0 : f32 to i32
@@ -35,6 +40,11 @@ spirv.module Logical GLSL450 requires #spirv.vce {
     %0 = spirv.ConvertFToU %arg0 : f64 to i32
     spirv.ReturnValue %0 : i32
   }
+  spirv.func @convert_bf16_to_u32(%arg0 : bf16) -> i32 "None" {
+    // CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : bf16 to i32
+    %0 = spirv.ConvertFToU %arg0 : bf16 to i32
+    spirv.ReturnValue %0 : i32
+  }
   spirv.func @convert_s_to_f(%arg0 : i32) -> f32 "None" {
     // CHECK: {{%.*}} = spirv.ConvertSToF {{%.*}} : i32 to f32
    %0 = spirv.ConvertSToF %arg0 : i32 to f32
@@ -45,6 +55,11 @@ spirv.module Logical GLSL450 requires #spirv.vce {
     %0 = spirv.ConvertSToF %arg0 : i64 to f32
     spirv.ReturnValue %0 : f32
   }
+  spirv.func @convert_s64_to_bf16(%arg0 : i64) -> bf16 "None" {
+    // CHECK: {{%.*}} = spirv.ConvertSToF {{%.*}} : i64 to bf16
+    %0 = spirv.ConvertSToF %arg0 : i64 to bf16
+    spirv.ReturnValue %0 : bf16
+  }
   spirv.func @convert_u_to_f(%arg0 : i32) -> f32 "None" {
     // CHECK: {{%.*}} = spirv.ConvertUToF {{%.*}} : i32 to f32
     %0 = spirv.ConvertUToF %arg0 : i32 to f32
@@ -55,11 +70,26 @@ spirv.module Logical GLSL450 requires #spirv.vce {
     %0 = spirv.ConvertUToF %arg0 : i64 to f32
     spirv.ReturnValue %0 : f32
   }
-  spirv.func @f_convert(%arg0 : f32) -> f64 "None" {
+  spirv.func @convert_u64_to_bf16(%arg0 : i64) -> bf16 "None" {
+    // CHECK: {{%.*}} = spirv.ConvertUToF {{%.*}} : i64 to bf16
+    %0 = spirv.ConvertUToF %arg0 : i64 to bf16
+    spirv.ReturnValue %0 : bf16
+  }
+  spirv.func @convert_f32_to_f64(%arg0 : f32) -> f64 "None" {
     // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : f32 to f64
     %0 = spirv.FConvert %arg0 : f32 to f64
     spirv.ReturnValue %0 : f64
   }
+  spirv.func @convert_f32_to_bf16(%arg0 : f32) -> bf16 "None" {
+    // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : f32 to bf16
+    %0 = spirv.FConvert %arg0 : f32 to bf16
+    spirv.ReturnValue %0 : bf16
+  }
+  spirv.func @convert_bf16_to_f32(%arg0 : bf16) -> f32 "None" {
+    // CHECK: {{%.*}} = spirv.FConvert {{%.*}} : bf16 to f32
+    %0 = spirv.FConvert %arg0 : bf16 to f32
+    spirv.ReturnValue %0 : f32
+  }
   spirv.func @s_convert(%arg0 : i32) -> i64 "None" {
     // CHECK: {{%.*}} = spirv.SConvert {{%.*}} : i32 to i64
     %0 = spirv.SConvert %arg0 : i32 to i64
diff --git a/mlir/test/Target/SPIRV/logical-ops.mlir b/mlir/test/Target/SPIRV/logical-ops.mlir
index 16846ac84e38c..b2008719b021c 100644
--- a/mlir/test/Target/SPIRV/logical-ops.mlir
+++ b/mlir/test/Target/SPIRV/logical-ops.mlir
@@ -108,3 +108,26 @@ spirv.module Logical GLSL450 requires #spirv.vce {
     spirv.Return
   }
 }
+
+// -----
+
+// Test select works with bf16 scalar and vectors.
+
+spirv.module Logical GLSL450 requires #spirv.vce {
+  spirv.SpecConstant @condition_scalar = true
+  spirv.func @select_bf16() -> () "None" {
+    %0 = spirv.Constant 4.0 : bf16
+    %1 = spirv.Constant 5.0 : bf16
+    %2 = spirv.mlir.referenceof @condition_scalar : i1
+    // CHECK: spirv.Select {{.*}}, {{.*}}, {{.*}} : i1, bf16
+    %3 = spirv.Select %2, %0, %1 : i1, bf16
+    %4 = spirv.Constant dense<[2.0, 3.0, 4.0, 5.0]> : vector<4xbf16>
+    %5 = spirv.Constant dense<[6.0, 7.0, 8.0, 9.0]> : vector<4xbf16>
+    // CHECK: spirv.Select {{.*}}, {{.*}}, {{.*}} : i1, vector<4xbf16>
+    %6 = spirv.Select %2, %4, %5 : i1, vector<4xbf16>
+    %7 = spirv.Constant dense<[true, true, true, true]> : vector<4xi1>
+    // CHECK: spirv.Select {{.*}}, {{.*}}, {{.*}} : vector<4xi1>, vector<4xbf16>
+    %8 = spirv.Select %7, %4, %5 : vector<4xi1>, vector<4xbf16>
+    spirv.Return
+  }
+}