-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][xegpu] Remove vector contract to dpas size restriction #147470
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][xegpu] Remove vector contract to dpas size restriction #147470
Conversation
Removes contraction shape check to allow representing large workgroup-level workloads in preparation for distribtion.
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu Author: Adam Siemieniuk (adam-smnk) ChangesRemoves contraction shape check to allow representing large workgroup-level workloads in preparation for distribution. Full diff: https://github.com/llvm/llvm-project/pull/147470.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 0ec7129a40a66..2e6a16ddbfdaa 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -339,13 +339,6 @@ struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
- // TODO: Update shape validation to be target aware.
- auto accShape = accType.getShape();
- int64_t dimN = accShape[1];
- if (dimN != 8 && dimN != 16)
- return rewriter.notifyMatchFailure(contractOp,
- "Invalid operand dimensions");
-
auto dpasOp = rewriter.create<xegpu::DpasOp>(
loc, TypeRange{contractOp.getResultType()}, ValueRange{lhs, rhs, acc});
rewriter.replaceOp(contractOp, dpasOp);
diff --git a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
index 8857ac204adca..38bda39d3aca2 100644
--- a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
@@ -48,6 +48,34 @@ func.func @dpas_gemm_i8(%lhs: vector<8x32xi8>, %rhs: vector<32x16xi8>,
// -----
+// No restriction on vector sizes to allow capturing workgroup-sized operations.
+// The operations can then be progressively resized through distribution down
+// to hardware compatible sizes.
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @dpas_large_dims(%lhs: vector<128x512xf16>, %rhs: vector<512x256xf16>,
+ %acc: vector<128x256xf32>) -> vector<128x256xf32> {
+ %3 = vector.contract
+ {indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ kind = #vector.kind<add>} %lhs, %rhs, %acc
+ : vector<128x512xf16>, vector<512x256xf16> into vector<128x256xf32>
+ return %3 : vector<128x256xf32>
+}
+
+// CHECK-LABEL: @dpas_large_dims(
+// CHECK-SAME: %[[LHS:.+]]: vector<128x512xf16>,
+// CHECK-SAME: %[[RHS:.+]]: vector<512x256xf16>,
+// CHECK-SAME: %[[ACC:.+]]: vector<128x256xf32>
+// CHECK: %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME: {{.*}}-> vector<128x256xf32>
+// CHECK: return %[[DPAS]]
+
+// -----
+
// For simplicity, only plain data layouts are currently supported.
// VNNI packing is applied later as a separate lowering step.
@@ -138,21 +166,3 @@ func.func @negative_gemm_transpose_b(%lhs: vector<8x16xf16>, %rhs: vector<16x16x
// CHECK-LABEL: @negative_gemm_transpose_b(
// CHECK: vector.contract
-
-// -----
-
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @negative_n_dim_size(%lhs: vector<8x16xf16>, %rhs: vector<16x32xf16>,
- %acc: vector<8x32xf32>) -> vector<8x32xf32> {
- %3 = vector.contract
- {indexing_maps = [#map, #map1, #map2],
- iterator_types = ["parallel", "parallel", "reduction"],
- kind = #vector.kind<add>} %lhs, %rhs, %acc
- : vector<8x16xf16>, vector<16x32xf16> into vector<8x32xf32>
- return %3 : vector<8x32xf32>
-}
-
-// CHECK-LABEL: @negative_n_dim_size(
-// CHECK: vector.contract
|
chencha3
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks for your efforts.
Removes contraction shape check to allow representing large workgroup-level workloads in preparation for distribution.