Skip to content

Conversation

@adam-smnk
Copy link
Contributor

@adam-smnk adam-smnk commented Jul 8, 2025

Removes contraction shape check to allow representing large workgroup-level workloads in preparation for distribution.

Removes contraction shape check to allow representing large
workgroup-level workloads in preparation for distribtion.
@llvmbot
Copy link
Member

llvmbot commented Jul 8, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Adam Siemieniuk (adam-smnk)

Changes

Removes contraction shape check to allow representing large workgroup-level workloads in preparation for distribution.


Full diff: https://github.com/llvm/llvm-project/pull/147470.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp (-7)
  • (modified) mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir (+28-18)
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 0ec7129a40a66..2e6a16ddbfdaa 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -339,13 +339,6 @@ struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
     if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
       return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
 
-    // TODO: Update shape validation to be target aware.
-    auto accShape = accType.getShape();
-    int64_t dimN = accShape[1];
-    if (dimN != 8 && dimN != 16)
-      return rewriter.notifyMatchFailure(contractOp,
-                                         "Invalid operand dimensions");
-
     auto dpasOp = rewriter.create<xegpu::DpasOp>(
         loc, TypeRange{contractOp.getResultType()}, ValueRange{lhs, rhs, acc});
     rewriter.replaceOp(contractOp, dpasOp);
diff --git a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
index 8857ac204adca..38bda39d3aca2 100644
--- a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
@@ -48,6 +48,34 @@ func.func @dpas_gemm_i8(%lhs: vector<8x32xi8>, %rhs: vector<32x16xi8>,
 
 // -----
 
+// No restriction on vector sizes to allow capturing workgroup-sized operations.
+// The operations can then be progressively resized through distribution down
+// to hardware compatible sizes.
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @dpas_large_dims(%lhs: vector<128x512xf16>, %rhs: vector<512x256xf16>,
+    %acc: vector<128x256xf32>) -> vector<128x256xf32> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<128x512xf16>, vector<512x256xf16> into vector<128x256xf32>
+  return %3 : vector<128x256xf32>
+}
+
+// CHECK-LABEL: @dpas_large_dims(
+// CHECK-SAME:  %[[LHS:.+]]: vector<128x512xf16>,
+// CHECK-SAME:  %[[RHS:.+]]: vector<512x256xf16>,
+// CHECK-SAME:  %[[ACC:.+]]: vector<128x256xf32>
+// CHECK:       %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME:    %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME:    {{.*}}-> vector<128x256xf32>
+// CHECK:       return %[[DPAS]]
+
+// -----
+
 // For simplicity, only plain data layouts are currently supported.
 // VNNI packing is applied later as a separate lowering step.
 
@@ -138,21 +166,3 @@ func.func @negative_gemm_transpose_b(%lhs: vector<8x16xf16>, %rhs: vector<16x16x
 
 // CHECK-LABEL: @negative_gemm_transpose_b(
 // CHECK:       vector.contract
-
-// -----
-
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @negative_n_dim_size(%lhs: vector<8x16xf16>, %rhs: vector<16x32xf16>,
-    %acc: vector<8x32xf32>) -> vector<8x32xf32> {
-  %3 = vector.contract
-    {indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction"],
-    kind = #vector.kind<add>} %lhs, %rhs, %acc
-    : vector<8x16xf16>, vector<16x32xf16> into vector<8x32xf32>
-  return %3 : vector<8x32xf32>
-}
-
-// CHECK-LABEL: @negative_n_dim_size(
-// CHECK:       vector.contract

@adam-smnk
Copy link
Contributor Author

FYI @tkarna
The lowering will be relaxed further in a follow-up PR after #145916 lands

Copy link
Contributor

@chencha3 chencha3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks for your efforts.

@adam-smnk adam-smnk merged commit 06ae0c2 into llvm:main Jul 9, 2025
12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants