File tree Expand file tree Collapse file tree 1 file changed +13
-3
lines changed
mlir/lib/Dialect/XeGPU/Utils Expand file tree Collapse file tree 1 file changed +13
-3
lines changed Original file line number Diff line number Diff line change 1212
1313#include " mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
1414#include " mlir/Dialect/XeGPU/IR/XeGPU.h"
15+ #include < cstdint>
1516#include < numeric>
1617
1718using namespace mlir ;
@@ -64,10 +65,19 @@ mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
6465FailureOr<VectorType>
6566mlir::xegpu::getDistributedVectorType (VectorType originalType,
6667 xegpu::LayoutAttr layout) {
67- auto shape = originalType.getShape ();
68+ int64_t rank = originalType.getRank ();
69+ // / Distributed vector type is only supported for 1D, 2D and 3D vectors.
70+ if (rank < 1 || rank > 3 )
71+ return failure ();
72+ ArrayRef<int64_t > shape = originalType.getShape ();
73+ // / arrayLength is 1 for 1D and 2D vectors, and equal to the first dimension
74+ // / of the 3D vector.
75+ int arrayLength = 1 ;
76+ if (rank == 3 )
77+ arrayLength = shape[0 ];
6878 auto helperTdescTy = xegpu::TensorDescType::get (
69- shape, originalType.getElementType (),
70- /* array_length= */ 1 , /* boundary_check=*/ true ,
79+ shape, originalType.getElementType (), arrayLength,
80+ /* boundary_check=*/ true ,
7181 /* memory_space=*/ xegpu::MemorySpace::Global, layout);
7282 return xegpu::getDistributedVectorType (helperTdescTy);
7383}
You can’t perform that action at this time.
0 commit comments