Skip to content

Conversation

@Hsiangkai
Copy link
Contributor

When rewriting multiple CompositeInserts to CompositeConstruct, we need to know the number of elements of the result type. However, we cannot query the number of elements for cooperative matrix types.

When rewriting multiple CompositeInserts to CompositeConstruct, we need
to know the number of elements of the result type. However, we cannot
query the number of elements for cooperative matrix types.
@llvmbot
Copy link
Member

llvmbot commented Apr 29, 2025

@llvm/pr-subscribers-mlir

Author: Hsiangkai Wang (Hsiangkai)

Changes

When rewriting multiple CompositeInserts to CompositeConstruct, we need to know the number of elements of the result type. However, we cannot query the number of elements for cooperative matrix types.


Full diff: https://github.com/llvm/llvm-project/pull/137837.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp (+3)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
index f38282f57a2c3..bc3d0429efd19 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
@@ -84,6 +84,9 @@ void RewriteInsertsPass::runOnOperation() {
 LogicalResult RewriteInsertsPass::collectInsertionChain(
     spirv::CompositeInsertOp op,
     SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
+  if (llvm::isa<spirv::CooperativeMatrixType>(op.getComposite().getType()))
+    return failure();
+
   auto indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
   // TODO: handle nested composite object.
   if (indicesArrayAttr.size() == 1) {

@llvmbot
Copy link
Member

llvmbot commented Apr 29, 2025

@llvm/pr-subscribers-mlir-spirv

Author: Hsiangkai Wang (Hsiangkai)

Changes

When rewriting multiple CompositeInserts to CompositeConstruct, we need to know the number of elements of the result type. However, we cannot query the number of elements for cooperative matrix types.


Full diff: https://github.com/llvm/llvm-project/pull/137837.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp (+3)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
index f38282f57a2c3..bc3d0429efd19 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
@@ -84,6 +84,9 @@ void RewriteInsertsPass::runOnOperation() {
 LogicalResult RewriteInsertsPass::collectInsertionChain(
     spirv::CompositeInsertOp op,
     SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
+  if (llvm::isa<spirv::CooperativeMatrixType>(op.getComposite().getType()))
+    return failure();
+
   auto indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
   // TODO: handle nested composite object.
   if (indicesArrayAttr.size() == 1) {

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we have a test?

LogicalResult RewriteInsertsPass::collectInsertionChain(
spirv::CompositeInsertOp op,
SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
if (llvm::isa<spirv::CooperativeMatrixType>(op.getComposite().getType()))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (llvm::isa<spirv::CooperativeMatrixType>(op.getComposite().getType()))
if (isa<spirv::CooperativeMatrixType>(op.getComposite().getType()))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@Hsiangkai
Copy link
Contributor Author

Can we have a test?

I added a test to ensure it will not crash when dealing with coopmma values.

@Hsiangkai
Copy link
Contributor Author

Ping.

@Hsiangkai Hsiangkai requested a review from krzysz00 May 20, 2025 09:18
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@Hsiangkai Hsiangkai merged commit 9f1da90 into llvm:main May 21, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants