diff --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp index f38282f57a2c3..2e31172ab940b 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp @@ -84,6 +84,9 @@ void RewriteInsertsPass::runOnOperation() { LogicalResult RewriteInsertsPass::collectInsertionChain( spirv::CompositeInsertOp op, SmallVectorImpl &insertions) { + if (isa(op.getComposite().getType())) + return failure(); + auto indicesArrayAttr = cast(op.getIndices()); // TODO: handle nested composite object. if (indicesArrayAttr.size() == 1) { diff --git a/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir b/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir index 6d755be4f3987..a83c3f7d34693 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir @@ -29,3 +29,15 @@ spirv.module Logical GLSL450 { spirv.ReturnValue %3 : vector<3xf32> } } + +// ----- + +spirv.module Logical GLSL450 { + spirv.func @insertCoopMatrix(%value : f32) -> !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> "None" { + %0 = spirv.Undef : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> + // CHECK: spirv.CompositeInsert {{%.*}}, {{%.*}} : f32 into !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> + %1 = spirv.CompositeInsert %value, %0[0 : i32] : f32 into !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> + + spirv.ReturnValue %1 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> + } +}