diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 0ec7129a40a66..2e6a16ddbfdaa 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -339,13 +339,6 @@ struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
     if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
       return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");
 
-    // TODO: Update shape validation to be target aware.
-    auto accShape = accType.getShape();
-    int64_t dimN = accShape[1];
-    if (dimN != 8 && dimN != 16)
-      return rewriter.notifyMatchFailure(contractOp,
-                                         "Invalid operand dimensions");
-
     auto dpasOp = rewriter.create<xegpu::DpasOp>(
         loc, TypeRange{contractOp.getResultType()}, ValueRange{lhs, rhs, acc});
     rewriter.replaceOp(contractOp, dpasOp);
diff --git a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
index 8857ac204adca..38bda39d3aca2 100644
--- a/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/contract-to-xegpu.mlir
@@ -48,6 +48,34 @@ func.func @dpas_gemm_i8(%lhs: vector<8x32xi8>, %rhs: vector<32x16xi8>,
 
 // -----
 
+// No restriction on vector sizes to allow capturing workgroup-sized operations.
+// The operations can then be progressively resized through distribution down
+// to hardware compatible sizes.
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @dpas_large_dims(%lhs: vector<128x512xf16>, %rhs: vector<512x256xf16>,
+    %acc: vector<128x256xf32>) -> vector<128x256xf32> {
+  %3 = vector.contract
+    {indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<128x512xf16>, vector<512x256xf16> into vector<128x256xf32>
+  return %3 : vector<128x256xf32>
+}
+
+// CHECK-LABEL: @dpas_large_dims(
+// CHECK-SAME: %[[LHS:.+]]: vector<128x512xf16>,
+// CHECK-SAME: %[[RHS:.+]]: vector<512x256xf16>,
+// CHECK-SAME: %[[ACC:.+]]: vector<128x256xf32>
+// CHECK: %[[DPAS:.+]] = xegpu.dpas
+// CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK-SAME: {{.*}}-> vector<128x256xf32>
+// CHECK: return %[[DPAS]]
+
+// -----
+
 // For simplicity, only plain data layouts are currently supported.
 // VNNI packing is applied later as a separate lowering step.
 
@@ -138,21 +166,3 @@ func.func @negative_gemm_transpose_b(%lhs: vector<8x16xf16>, %rhs: vector<16x16x
 
 // CHECK-LABEL: @negative_gemm_transpose_b(
 // CHECK: vector.contract
-
-// -----
-
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @negative_n_dim_size(%lhs: vector<8x16xf16>, %rhs: vector<16x32xf16>,
-    %acc: vector<8x32xf32>) -> vector<8x32xf32> {
-  %3 = vector.contract
-    {indexing_maps = [#map, #map1, #map2],
-    iterator_types = ["parallel", "parallel", "reduction"],
-    kind = #vector.kind<add>} %lhs, %rhs, %acc
-    : vector<8x16xf16>, vector<16x32xf16> into vector<8x32xf32>
-  return %3 : vector<8x32xf32>
-}
-
-// CHECK-LABEL: @negative_n_dim_size(
-// CHECK: vector.contract