Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 34 additions & 4 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ def SPV_KHR_subgroup_rotate : I32EnumAttrCase<"SPV_KHR_subgroup
def SPV_KHR_non_semantic_info : I32EnumAttrCase<"SPV_KHR_non_semantic_info", 29>;
def SPV_KHR_terminate_invocation : I32EnumAttrCase<"SPV_KHR_terminate_invocation", 30>;
def SPV_KHR_cooperative_matrix : I32EnumAttrCase<"SPV_KHR_cooperative_matrix", 31>;
def SPV_KHR_bfloat16 : I32EnumAttrCase<"SPV_KHR_bfloat16", 32>;

def SPV_EXT_demote_to_helper_invocation : I32EnumAttrCase<"SPV_EXT_demote_to_helper_invocation", 1000>;
def SPV_EXT_descriptor_indexing : I32EnumAttrCase<"SPV_EXT_descriptor_indexing", 1001>;
Expand Down Expand Up @@ -436,7 +437,7 @@ def SPIRV_ExtensionAttr :
SPV_KHR_fragment_shader_barycentric, SPV_KHR_ray_cull_mask,
SPV_KHR_uniform_group_instructions, SPV_KHR_subgroup_rotate,
SPV_KHR_non_semantic_info, SPV_KHR_terminate_invocation,
SPV_KHR_cooperative_matrix,
SPV_KHR_cooperative_matrix, SPV_KHR_bfloat16,
SPV_EXT_demote_to_helper_invocation, SPV_EXT_descriptor_indexing,
SPV_EXT_fragment_fully_covered, SPV_EXT_fragment_invocation_density,
SPV_EXT_fragment_shader_interlock, SPV_EXT_physical_storage_buffer,
Expand Down Expand Up @@ -1412,6 +1413,23 @@ def SPIRV_C_ShaderStereoViewNV : I32EnumAttrCase<"Shade
Extension<[SPV_NV_stereo_view_rendering]>
];
}
def SPIRV_C_BFloat16TypeKHR : I32EnumAttrCase<"BFloat16TypeKHR", 5116> {
list<Availability> availability = [
Extension<[SPV_KHR_bfloat16]>
];
}
def SPIRV_C_BFloat16DotProductKHR : I32EnumAttrCase<"BFloat16DotProductKHR", 5117> {
list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR];
list<Availability> availability = [
Extension<[SPV_KHR_bfloat16]>
];
}
def SPIRV_C_BFloat16CooperativeMatrixKHR : I32EnumAttrCase<"BFloat16CooperativeMatrixKHR", 5118> {
list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR, SPIRV_C_CooperativeMatrixKHR];
list<Availability> availability = [
Extension<[SPV_KHR_bfloat16]>
];
}

def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"Bfloat16ConversionINTEL", 6115> {
list<Availability> availability = [
Expand Down Expand Up @@ -1518,7 +1536,8 @@ def SPIRV_CapabilityAttr :
SPIRV_C_StorageTexelBufferArrayNonUniformIndexing,
SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
SPIRV_C_CacheControlsINTEL
SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR
]>;

def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
Expand Down Expand Up @@ -3217,6 +3236,16 @@ def SPIRV_ExecutionModelAttr :
SPIRV_EM_TaskEXT, SPIRV_EM_MeshEXT
]>;

def SPIRV_FPE_BFloat16KHR : I32EnumAttrCase<"BFloat16KHR", 0> {
list<Availability> availability = [
Capability<[SPIRV_C_BFloat16TypeKHR]>
];
}
def SPIRV_FPEncodingAttr :
SPIRV_I32EnumAttr<"FPEncoding", "valid SPIR-V FPEncoding", "f_p_encoding", [
SPIRV_FPE_BFloat16KHR
]>;

def SPIRV_FC_None : I32BitEnumAttrCaseNone<"None">;
def SPIRV_FC_Inline : I32BitEnumAttrCaseBit<"Inline", 0>;
def SPIRV_FC_DontInline : I32BitEnumAttrCaseBit<"DontInline", 1>;
Expand Down Expand Up @@ -4163,8 +4192,9 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
def SPIRV_FloatOrBFloat16 : AnyTypeOf<[SPIRV_Float, BF16]>;
def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
[SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
[SPIRV_Bool, SPIRV_Integer, SPIRV_FloatOrBFloat16]>;
Copy link
Contributor Author

@fairywreath fairywreath May 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can argue that bf16 should be part of SPIRV_Float. The problem here is that bf16 usage in SPIRV is very limited while SPIRV_Float(i.e regular floats) is used widely in the codebase for other ops(eg. texture sampling and regular arithmetic insts). I chose to leave SPIRV_Float to minimize the amount of changes(and to not introduce something like SPIRV_ArithmeticFloat). Please let me know if you think there is a cleaner solution to this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with this, it might be better to leave them alone. what about something like this:

def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
def SPIRV_Float : AnyTypeOf<[F16, F32, F64]>;
def SPIRV_Float16or32 : AnyTypeOf<[F16, F32]>;
// Use this type for all kinds of floats.
def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_BFloat16KHR, SPIRV_Float]>;
.....
def SPIRV_Vector : VectorOfLengthRangeAndType<[2, 0xFFFFFFFF],
                                       [SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_BFloat16KHR]>;
......

def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_Float, SPIRV_BFloat16KHR]>;
.........

def SPIRV_Type : AnyTypeOf<[
    SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_BFloat16KHR, SPIRV_Vector,
     SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
     SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
   ]>;

// Component type check is done in the type parser for the following SPIR-V
// dialect-specific types so we use "Any" here.
def SPIRV_AnyPtr : DialectType<SPIRV_Dialect, SPIRV_IsPtrType,
Expand Down Expand Up @@ -4194,7 +4224,7 @@ def SPIRV_Composite :
AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix]>;
def SPIRV_Type : AnyTypeOf<[
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_FloatOrBFloat16, SPIRV_Vector,
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
]>;
Expand Down
12 changes: 6 additions & 6 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {

// -----

def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float, []> {
def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_FloatOrBFloat16, []> {
let summary = [{
Convert value numerically from floating point to signed integer, with
round toward 0.0.
Expand All @@ -111,7 +111,7 @@ def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float

// -----

def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float, []> {
def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_FloatOrBFloat16, []> {
let summary = [{
Convert value numerically from floating point to unsigned integer, with
round toward 0.0.
Expand All @@ -138,7 +138,7 @@ def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float
// -----

def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
SPIRV_Float,
SPIRV_FloatOrBFloat16,
SPIRV_Integer,
[SignedOp]> {
let summary = [{
Expand All @@ -165,7 +165,7 @@ def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
// -----

def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
SPIRV_Float,
SPIRV_FloatOrBFloat16,
SPIRV_Integer,
[UnsignedOp]> {
let summary = [{
Expand All @@ -192,8 +192,8 @@ def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
// -----

def SPIRV_FConvertOp : SPIRV_CastOp<"FConvert",
SPIRV_Float,
SPIRV_Float,
SPIRV_FloatOrBFloat16,
SPIRV_FloatOrBFloat16,
[UsableInSpecConstantOp]> {
let summary = [{
Convert value numerically from one floating-point width to another
Expand Down
5 changes: 1 addition & 4 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,7 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,

// Check other allowed types
if (auto t = llvm::dyn_cast<FloatType>(type)) {
if (type.isBF16()) {
parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
return Type();
}
// TODO: All float types are allowed for now, but this should be fixed.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will address this in a separate PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please elaborate what needs to be fixed here?

Copy link
Contributor Author

@fairywreath fairywreath May 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current behavior does not error out on bitwidths that are invalid for SPIRV (eg. F80, F128) and non-standard formats (eg. E3M2). Do you think it's better to address this here or in a separate PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my opinion it's okay to address it later. In fact I think it's preferable. Currently the code doesn't do any checks anyway, other than checking for bf16, so adding a proper check would be out of scope of this PR.

} else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
if (!ScalarType::isValid(t)) {
parser.emitError(typeLoc,
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ bool ScalarType::classof(Type type) {
}

bool ScalarType::isValid(FloatType type) {
return llvm::is_contained({16u, 32u, 64u}, type.getWidth()) && !type.isBF16();
return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
}

bool ScalarType::isValid(IntegerType type) {
Expand Down
21 changes: 18 additions & 3 deletions mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -866,11 +866,12 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
} break;
case spirv::Opcode::OpTypeFloat: {
if (operands.size() != 2)
if (operands.size() < 2)
return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
uint32_t bitWidth = operands[1];

Type floatTy;
switch (operands[1]) {
switch (bitWidth) {
case 16:
floatTy = opBuilder.getF16Type();
break;
Expand All @@ -882,8 +883,22 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
break;
default:
return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
<< operands[1];
<< bitWidth;
}

if (operands.size() == 3) {
if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR) {
return emitError(unknownLoc, "unsupported OpTypeFloat FP encoding: ")
<< operands[2];
}
if (bitWidth != 16) {
return emitError(unknownLoc,
"invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
<< bitWidth << " (expected 16)";
}
floatTy = opBuilder.getBF16Type();
}

typeMap[operands[0]] = floatTy;
} break;
case spirv::Opcode::OpTypeVector: {
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,9 @@ LogicalResult Serializer::prepareBasicType(
if (auto floatType = dyn_cast<FloatType>(type)) {
typeEnum = spirv::Opcode::OpTypeFloat;
operands.push_back(floatType.getWidth());
if (floatType.isBF16()) {
operands.push_back(static_cast<uint32_t>(spirv::FPEncoding::BFloat16KHR));
}
return success();
}

Expand Down
18 changes: 6 additions & 12 deletions mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,12 @@ func.func @float16(%arg0: f16) { return }
// NOEMU-SAME: f64
func.func @float64(%arg0: f64) { return }

// CHECK-LABEL: spirv.func @bfloat16
// CHECK-SAME: f32
// NOEMU-LABEL: func.func @bfloat16
// NOEMU-SAME: bf16
func.func @bfloat16(%arg0: bf16) { return }

// f80 is not supported by SPIR-V.
// CHECK-LABEL: func.func @float80
// CHECK-SAME: f80
Expand Down Expand Up @@ -206,18 +212,6 @@ func.func @float64(%arg0: f64) { return }

// -----

// Check that bf16 is not supported.
module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
} {

// CHECK-NOT: spirv.func @bf16_type
func.func @bf16_type(%arg0: bf16) { return }

} // end module

// -----

//===----------------------------------------------------------------------===//
// Complex types
//===----------------------------------------------------------------------===//
Expand Down
56 changes: 56 additions & 0 deletions mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@ func.func @convert_f_to_s_vector(%arg0 : vector<3xf32>) -> vector<3xi32> {

// -----

func.func @convert_bf16_to_s32_scalar(%arg0 : bf16) -> i32 {
// CHECK: {{%.*}} = spirv.ConvertFToS {{%.*}} : bf16 to i32
%0 = spirv.ConvertFToS %arg0 : bf16 to i32
spirv.ReturnValue %0 : i32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.ConvertFToU
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -146,6 +154,14 @@ func.func @convert_f_to_u.coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgrou

// -----

func.func @convert_bf16_to_u32_scalar(%arg0 : bf16) -> i32 {
// CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : bf16 to i32
%0 = spirv.ConvertFToU %arg0 : bf16 to i32
spirv.ReturnValue %0 : i32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.ConvertSToF
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -174,6 +190,14 @@ func.func @convert_s_to_f_vector(%arg0 : vector<3xi32>) -> vector<3xf32> {

// -----

func.func @convert_s32_to_bf16_scalar(%arg0 : i32) -> bf16 {
// CHECK: {{%.*}} = spirv.ConvertSToF {{%.*}} : i32 to bf16
%0 = spirv.ConvertSToF %arg0 : i32 to bf16
spirv.ReturnValue %0 : bf16
}

// -----

//===----------------------------------------------------------------------===//
// spirv.ConvertUToF
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -202,6 +226,14 @@ func.func @convert_u_to_f_vector(%arg0 : vector<3xi32>) -> vector<3xf32> {

// -----

func.func @convert_u32_to_bf16_scalar(%arg0 : i32) -> bf16 {
// CHECK: {{%.*}} = spirv.ConvertUToF {{%.*}} : i32 to bf16
%0 = spirv.ConvertUToF %arg0 : i32 to bf16
spirv.ReturnValue %0 : bf16
}

// -----

//===----------------------------------------------------------------------===//
// spirv.FConvert
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -238,6 +270,30 @@ func.func @f_convert_vector(%arg0 : f32) -> f32 {

// -----

func.func @f_convert_bf16_to_f32_scalar(%arg0 : bf16) -> f32 {
// CHECK: {{%.*}} = spirv.FConvert {{%.*}} : bf16 to f32
%0 = spirv.FConvert %arg0 : bf16 to f32
spirv.ReturnValue %0 : f32
}

// -----

func.func @f_convert_f32_to_bf16_vector(%arg0 : vector<3xf32>) -> vector<3xbf16> {
// CHECK: {{%.*}} = spirv.FConvert {{%.*}} : vector<3xf32> to vector<3xbf16>
%0 = spirv.FConvert %arg0 : vector<3xf32> to vector<3xbf16>
spirv.ReturnValue %0 : vector<3xbf16>
}

// -----

func.func @f_convert_f32_to_bf16_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>) -> !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA> {
// CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> to !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
%0 = spirv.FConvert %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> to !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
spirv.ReturnValue %0 : !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
}

// -----

//===----------------------------------------------------------------------===//
// spirv.SConvert
//===----------------------------------------------------------------------===//
Expand Down
8 changes: 3 additions & 5 deletions mlir/test/Dialect/SPIRV/IR/types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ func.func private @vector_array_type(!spirv.array< 32 x vector<4xf32> >) -> ()
// CHECK: func private @array_type_stride(!spirv.array<4 x !spirv.array<4 x f32, stride=4>, stride=128>)
func.func private @array_type_stride(!spirv.array< 4 x !spirv.array<4 x f32, stride=4>, stride = 128>) -> ()

// CHECK: func private @vector_array_type_bf16(!spirv.array<32 x vector<4xbf16>>)
func.func private @vector_array_type_bf16(!spirv.array<32 x vector<4xbf16> >) -> ()

// -----

// expected-error @+1 {{expected '<'}}
Expand Down Expand Up @@ -57,11 +60,6 @@ func.func private @tensor_type(!spirv.array<4xtensor<4xf32>>) -> ()

// -----

// expected-error @+1 {{cannot use 'bf16' to compose SPIR-V types}}
func.func private @bf16_type(!spirv.array<4xbf16>) -> ()

// -----

// expected-error @+1 {{only 1/8/16/32/64-bit integer type allowed but found 'i256'}}
func.func private @i256_type(!spirv.array<4xi256>) -> ()

Expand Down
Loading
Loading