Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -445,12 +445,12 @@ def SPIRV_DotOp : SPIRV_Op<"Dot",
}];

let arguments = (ins
SPIRV_VectorOf<SPIRV_Float>:$vector1,
SPIRV_VectorOf<SPIRV_Float>:$vector2
SPIRV_VectorOf<SPIRV_FloatOrBFloat16>:$vector1,
SPIRV_VectorOf<SPIRV_FloatOrBFloat16>:$vector2
);

let results = (outs
SPIRV_Float:$result
SPIRV_FloatOrBFloat16:$result
);

let assemblyFormat = "operands attr-dict `:` type($vector1) `->` type($result)";
Expand Down
40 changes: 35 additions & 5 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ def SPV_KHR_subgroup_rotate : I32EnumAttrCase<"SPV_KHR_subgroup
def SPV_KHR_non_semantic_info : I32EnumAttrCase<"SPV_KHR_non_semantic_info", 29>;
def SPV_KHR_terminate_invocation : I32EnumAttrCase<"SPV_KHR_terminate_invocation", 30>;
def SPV_KHR_cooperative_matrix : I32EnumAttrCase<"SPV_KHR_cooperative_matrix", 31>;
def SPV_KHR_bfloat16 : I32EnumAttrCase<"SPV_KHR_bfloat16", 32>;

def SPV_EXT_demote_to_helper_invocation : I32EnumAttrCase<"SPV_EXT_demote_to_helper_invocation", 1000>;
def SPV_EXT_descriptor_indexing : I32EnumAttrCase<"SPV_EXT_descriptor_indexing", 1001>;
Expand Down Expand Up @@ -436,7 +437,7 @@ def SPIRV_ExtensionAttr :
SPV_KHR_fragment_shader_barycentric, SPV_KHR_ray_cull_mask,
SPV_KHR_uniform_group_instructions, SPV_KHR_subgroup_rotate,
SPV_KHR_non_semantic_info, SPV_KHR_terminate_invocation,
SPV_KHR_cooperative_matrix,
SPV_KHR_cooperative_matrix, SPV_KHR_bfloat16,
SPV_EXT_demote_to_helper_invocation, SPV_EXT_descriptor_indexing,
SPV_EXT_fragment_fully_covered, SPV_EXT_fragment_invocation_density,
SPV_EXT_fragment_shader_interlock, SPV_EXT_physical_storage_buffer,
Expand Down Expand Up @@ -1412,6 +1413,23 @@ def SPIRV_C_ShaderStereoViewNV : I32EnumAttrCase<"Shade
Extension<[SPV_NV_stereo_view_rendering]>
];
}
def SPIRV_C_BFloat16TypeKHR : I32EnumAttrCase<"BFloat16TypeKHR", 5116> {
list<Availability> availability = [
Extension<[SPV_KHR_bfloat16]>
];
}
def SPIRV_C_BFloat16DotProductKHR : I32EnumAttrCase<"BFloat16DotProductKHR", 5117> {
list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR];
list<Availability> availability = [
Extension<[SPV_KHR_bfloat16]>
];
}
def SPIRV_C_BFloat16CooperativeMatrixKHR : I32EnumAttrCase<"BFloat16CooperativeMatrixKHR", 5118> {
list<I32EnumAttrCase> implies = [SPIRV_C_BFloat16TypeKHR, SPIRV_C_CooperativeMatrixKHR];
list<Availability> availability = [
Extension<[SPV_KHR_bfloat16]>
];
}

def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"Bfloat16ConversionINTEL", 6115> {
list<Availability> availability = [
Expand Down Expand Up @@ -1518,7 +1536,8 @@ def SPIRV_CapabilityAttr :
SPIRV_C_StorageTexelBufferArrayNonUniformIndexing,
SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV,
SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL,
SPIRV_C_CacheControlsINTEL
SPIRV_C_CacheControlsINTEL, SPIRV_C_BFloat16TypeKHR,
SPIRV_C_BFloat16DotProductKHR, SPIRV_C_BFloat16CooperativeMatrixKHR,
]>;

def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>;
Expand Down Expand Up @@ -3217,6 +3236,16 @@ def SPIRV_ExecutionModelAttr :
SPIRV_EM_TaskEXT, SPIRV_EM_MeshEXT
]>;

def SPIRV_FPE_BFloat16KHR : I32EnumAttrCase<"BFloat16KHR", 0> {
list<Availability> availability = [
Capability<[SPIRV_C_BFloat16TypeKHR]>
];
}
def SPIRV_FPEncodingAttr :
SPIRV_I32EnumAttr<"FPEncoding", "valid SPIR-V FPEncoding", "f_p_encoding", [
SPIRV_FPE_BFloat16KHR
]>;

def SPIRV_FC_None : I32BitEnumAttrCaseNone<"None">;
def SPIRV_FC_Inline : I32BitEnumAttrCaseBit<"Inline", 0>;
def SPIRV_FC_DontInline : I32BitEnumAttrCaseBit<"DontInline", 1>;
Expand Down Expand Up @@ -4163,8 +4192,9 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
def SPIRV_FloatOrBFloat16 : AnyTypeOf<[SPIRV_Float, BF16]>;
def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
[SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
[SPIRV_Bool, SPIRV_Integer, SPIRV_FloatOrBFloat16]>;
Copy link
Contributor Author

@fairywreath fairywreath May 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can argue that bf16 should be part of SPIRV_Float. The problem here is that bf16 usage in SPIRV is very limited while SPIRV_Float(i.e regular floats) is used widely in the codebase for other ops(eg. texture sampling and regular arithmetic insts). I chose to leave SPIRV_Float to minimize the amount of changes(and to not introduce something like SPIRV_ArithmeticFloat). Please let me know if you think there is a cleaner solution to this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with this, it might be better to leave them alone. what about something like this:

def SPIRV_BFloat16KHR : TypeAlias<BF16, "BFloat16">;
def SPIRV_Float : AnyTypeOf<[F16, F32, F64]>;
def SPIRV_Float16or32 : AnyTypeOf<[F16, F32]>;
// Use this type for all kinds of floats.
def SPIRV_AnyFloat : AnyTypeOf<[SPIRV_BFloat16KHR, SPIRV_Float]>;
.....
def SPIRV_Vector : VectorOfLengthRangeAndType<[2, 0xFFFFFFFF],
                                       [SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_BFloat16KHR]>;
......

def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_Float, SPIRV_BFloat16KHR]>;
.........

def SPIRV_Type : AnyTypeOf<[
    SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_BFloat16KHR, SPIRV_Vector,
     SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
     SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
   ]>;

// Component type check is done in the type parser for the following SPIR-V
// dialect-specific types so we use "Any" here.
def SPIRV_AnyPtr : DialectType<SPIRV_Dialect, SPIRV_IsPtrType,
Expand Down Expand Up @@ -4194,9 +4224,9 @@ def SPIRV_Composite :
AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix]>;
def SPIRV_Type : AnyTypeOf<[
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_FloatOrBFloat16, SPIRV_Vector,
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
]>;

def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
Expand Down
12 changes: 6 additions & 6 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def SPIRV_BitcastOp : SPIRV_Op<"Bitcast", [Pure]> {

// -----

def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float, []> {
def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_FloatOrBFloat16, []> {
let summary = [{
Convert value numerically from floating point to signed integer, with
round toward 0.0.
Expand All @@ -111,7 +111,7 @@ def SPIRV_ConvertFToSOp : SPIRV_CastOp<"ConvertFToS", SPIRV_Integer, SPIRV_Float

// -----

def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float, []> {
def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_FloatOrBFloat16, []> {
let summary = [{
Convert value numerically from floating point to unsigned integer, with
round toward 0.0.
Expand All @@ -138,7 +138,7 @@ def SPIRV_ConvertFToUOp : SPIRV_CastOp<"ConvertFToU", SPIRV_Integer, SPIRV_Float
// -----

def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
SPIRV_Float,
SPIRV_FloatOrBFloat16,
SPIRV_Integer,
[SignedOp]> {
let summary = [{
Expand All @@ -165,7 +165,7 @@ def SPIRV_ConvertSToFOp : SPIRV_CastOp<"ConvertSToF",
// -----

def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
SPIRV_Float,
SPIRV_FloatOrBFloat16,
SPIRV_Integer,
[UnsignedOp]> {
let summary = [{
Expand All @@ -192,8 +192,8 @@ def SPIRV_ConvertUToFOp : SPIRV_CastOp<"ConvertUToF",
// -----

def SPIRV_FConvertOp : SPIRV_CastOp<"FConvert",
SPIRV_Float,
SPIRV_Float,
SPIRV_FloatOrBFloat16,
SPIRV_FloatOrBFloat16,
[UsableInSpecConstantOp]> {
let summary = [{
Convert value numerically from one floating-point width to another
Expand Down
5 changes: 1 addition & 4 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,7 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,

// Check other allowed types
if (auto t = llvm::dyn_cast<FloatType>(type)) {
if (type.isBF16()) {
parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
return Type();
}
// TODO: All float types are allowed for now, but this should be fixed.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will address this in a separate PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please elaborate what needs to be fixed here?

Copy link
Contributor Author

@fairywreath fairywreath May 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current behavior does not error out on bitwidths that are invalid for SPIRV (eg. F80, F128) and non-standard formats (eg. E3M2). Do you think it's better to address this here or in a separate PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my opinion it's okay to address it later. In fact I think it's preferable. Currently the code doesn't do any checks anyway, other than checking for bf16, so adding a proper check would be out of scope of this PR.

} else if (auto t = llvm::dyn_cast<IntegerType>(type)) {
if (!ScalarType::isValid(t)) {
parser.emitError(typeLoc,
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ bool ScalarType::classof(Type type) {
}

bool ScalarType::isValid(FloatType type) {
return llvm::is_contained({16u, 32u, 64u}, type.getWidth()) && !type.isBF16();
return llvm::is_contained({16u, 32u, 64u}, type.getWidth());
}

bool ScalarType::isValid(IntegerType type) {
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,9 @@ LogicalResult Serializer::prepareBasicType(
if (auto floatType = dyn_cast<FloatType>(type)) {
typeEnum = spirv::Opcode::OpTypeFloat;
operands.push_back(floatType.getWidth());
if (floatType.isBF16()) {
operands.push_back(static_cast<uint32_t>(spirv::FPEncoding::BFloat16KHR));
}
return success();
}

Expand Down
12 changes: 0 additions & 12 deletions mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -206,18 +206,6 @@ func.func @float64(%arg0: f64) { return }

// -----

// Check that bf16 is not supported.
module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
} {

// CHECK-NOT: spirv.func @bf16_type
func.func @bf16_type(%arg0: bf16) { return }

} // end module

// -----

//===----------------------------------------------------------------------===//
// Complex types
//===----------------------------------------------------------------------===//
Expand Down
9 changes: 8 additions & 1 deletion mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,13 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f32 {

// -----

func.func @dot_bf16(%arg0: vector<4xbf16>, %arg1: vector<4xbf16>) -> bf16 {
%0 = spirv.Dot %arg0, %arg1 : vector<4xbf16> -> bf16
return %0 : bf16
}

// -----

// expected-note @+1 {{prior use here}}
func.func @dot(%arg0: vector<4xf32>, %arg1: vector<3xf32>) -> f32 {
// expected-error @+1 {{use of value '%arg1' expects different type than prior uses}}
Expand All @@ -283,7 +290,7 @@ func.func @dot(%arg0: vector<4xf32>, %arg1: vector<4xf32>) -> f16 {
// -----

func.func @dot(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> i32 {
// expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4/8/16}}
// expected-error @+1 {{'spirv.Dot' op operand #0 must be vector of 16/32/64-bit float or bfloat16 type values of length 2/3/4/8/16}}
%0 = spirv.Dot %arg0, %arg1 : vector<4xi32> -> i32
return %0 : i32
}
Expand Down
56 changes: 56 additions & 0 deletions mlir/test/Dialect/SPIRV/IR/cast-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,14 @@ func.func @convert_f_to_s_vector(%arg0 : vector<3xf32>) -> vector<3xi32> {

// -----

func.func @convert_bf16_to_s32_scalar(%arg0 : bf16) -> i32 {
// CHECK: {{%.*}} = spirv.ConvertFToS {{%.*}} : bf16 to i32
%0 = spirv.ConvertFToS %arg0 : bf16 to i32
spirv.ReturnValue %0 : i32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.ConvertFToU
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -146,6 +154,14 @@ func.func @convert_f_to_u.coopmatrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgrou

// -----

func.func @convert_bf16_to_u32_scalar(%arg0 : bf16) -> i32 {
// CHECK: {{%.*}} = spirv.ConvertFToU {{%.*}} : bf16 to i32
%0 = spirv.ConvertFToU %arg0 : bf16 to i32
spirv.ReturnValue %0 : i32
}

// -----

//===----------------------------------------------------------------------===//
// spirv.ConvertSToF
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -174,6 +190,14 @@ func.func @convert_s_to_f_vector(%arg0 : vector<3xi32>) -> vector<3xf32> {

// -----

func.func @convert_s32_to_bf16_scalar(%arg0 : i32) -> bf16 {
// CHECK: {{%.*}} = spirv.ConvertSToF {{%.*}} : i32 to bf16
%0 = spirv.ConvertSToF %arg0 : i32 to bf16
spirv.ReturnValue %0 : bf16
}

// -----

//===----------------------------------------------------------------------===//
// spirv.ConvertUToF
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -202,6 +226,14 @@ func.func @convert_u_to_f_vector(%arg0 : vector<3xi32>) -> vector<3xf32> {

// -----

func.func @convert_u32_to_bf16_scalar(%arg0 : i32) -> bf16 {
// CHECK: {{%.*}} = spirv.ConvertUToF {{%.*}} : i32 to bf16
%0 = spirv.ConvertUToF %arg0 : i32 to bf16
spirv.ReturnValue %0 : bf16
}

// -----

//===----------------------------------------------------------------------===//
// spirv.FConvert
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -238,6 +270,30 @@ func.func @f_convert_vector(%arg0 : f32) -> f32 {

// -----

func.func @f_convert_bf16_to_f32_scalar(%arg0 : bf16) -> f32 {
// CHECK: {{%.*}} = spirv.FConvert {{%.*}} : bf16 to f32
%0 = spirv.FConvert %arg0 : bf16 to f32
spirv.ReturnValue %0 : f32
}

// -----

func.func @f_convert_f32_to_bf16_vector(%arg0 : vector<3xf32>) -> vector<3xbf16> {
// CHECK: {{%.*}} = spirv.FConvert {{%.*}} : vector<3xf32> to vector<3xbf16>
%0 = spirv.FConvert %arg0 : vector<3xf32> to vector<3xbf16>
spirv.ReturnValue %0 : vector<3xbf16>
}

// -----

func.func @f_convert_f32_to_bf16_coop_matrix(%arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>) -> !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA> {
// CHECK: {{%.*}} = spirv.FConvert {{%.*}} : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> to !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
%0 = spirv.FConvert %arg0 : !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> to !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
spirv.ReturnValue %0 : !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>
}

// -----

//===----------------------------------------------------------------------===//
// spirv.SConvert
//===----------------------------------------------------------------------===//
Expand Down
29 changes: 29 additions & 0 deletions mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ spirv.func @cooperative_matrix_load_memoperand(%ptr : !spirv.ptr<i32, StorageBuf
spirv.Return
}

// CHECK-LABEL: @cooperative_matrix_load_bf16
spirv.func @cooperative_matrix_load_bf16(%ptr : !spirv.ptr<bf16, StorageBuffer>, %stride : i32) "None" {
// CHECK: {{%.+}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>
// CHECK-SAME: : !spirv.ptr<bf16, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xbf16, Workgroup, MatrixA>
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <RowMajor> :
!spirv.ptr<bf16, StorageBuffer>, i32 -> !spirv.coopmatrix<16x8xbf16, Workgroup, MatrixA>
spirv.Return
}

// CHECK-LABEL: @cooperative_matrix_load_vector_ptr_type
spirv.func @cooperative_matrix_load_vector_ptr_type(%ptr : !spirv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32) "None" {
// CHECK: {{%.*}} = spirv.KHR.CooperativeMatrixLoad {{%.*}}, {{%.*}}, <RowMajor>, <Volatile> :
Expand Down Expand Up @@ -225,6 +234,26 @@ spirv.func @cooperative_matrix_muladd_f32(%a : !spirv.coopmatrix<4x4xf32, Subgro
spirv.Return
}

spirv.func @cooperative_matrix_muladd_bf16_bf16(%a : !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>,
%b : !spirv.coopmatrix<16x4xbf16, Subgroup, MatrixB>,
%c : !spirv.coopmatrix<8x4xbf16, Subgroup, MatrixAcc>) "None" {
%r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
!spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>,
!spirv.coopmatrix<16x4xbf16, Subgroup, MatrixB> ->
!spirv.coopmatrix<8x4xbf16, Subgroup, MatrixAcc>
spirv.Return
}

spirv.func @cooperative_matrix_muladd_bf16_f32(%a : !spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>,
%b : !spirv.coopmatrix<16x4xbf16, Subgroup, MatrixB>,
%c : !spirv.coopmatrix<8x4xf32, Subgroup, MatrixAcc>) "None" {
%r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c :
!spirv.coopmatrix<8x16xbf16, Subgroup, MatrixA>,
!spirv.coopmatrix<16x4xbf16, Subgroup, MatrixB> ->
!spirv.coopmatrix<8x4xf32, Subgroup, MatrixAcc>
spirv.Return
}

spirv.func @cooperative_matrix_muladd_i8_i32(%a : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
%b : !spirv.coopmatrix<16x4xi8, Subgroup, MatrixB>,
%c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
Expand Down
5 changes: 0 additions & 5 deletions mlir/test/Dialect/SPIRV/IR/types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,6 @@ func.func private @tensor_type(!spirv.array<4xtensor<4xf32>>) -> ()

// -----

// expected-error @+1 {{cannot use 'bf16' to compose SPIR-V types}}
func.func private @bf16_type(!spirv.array<4xbf16>) -> ()

// -----

// expected-error @+1 {{only 1/8/16/32/64-bit integer type allowed but found 'i256'}}
func.func private @i256_type(!spirv.array<4xi256>) -> ()

Expand Down