Skip to content

Commit bf9c0ab

Browse files
committed
save work
1 parent a72ec25 commit bf9c0ab

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
1414
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
15+
#include <cstdint>
1516
#include <numeric>
1617

1718
using namespace mlir;
@@ -64,10 +65,19 @@ mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
6465
FailureOr<VectorType>
6566
mlir::xegpu::getDistributedVectorType(VectorType originalType,
6667
xegpu::LayoutAttr layout) {
67-
auto shape = originalType.getShape();
68+
int64_t rank = originalType.getRank();
69+
/// Distributed vector type is only supported for 1D, 2D and 3D vectors.
70+
if (rank < 1 || rank > 3)
71+
return failure();
72+
ArrayRef<int64_t> shape = originalType.getShape();
73+
/// arrayLength is 1 for 1D and 2D vectors, and equal to the first dimension
74+
/// of the 3D vector.
75+
int arrayLength = 1;
76+
if (rank == 3)
77+
arrayLength = shape[0];
6878
auto helperTdescTy = xegpu::TensorDescType::get(
69-
shape, originalType.getElementType(),
70-
/*array_length=*/1, /*boundary_check=*/true,
79+
shape, originalType.getElementType(), arrayLength,
80+
/*boundary_check=*/true,
7181
/*memory_space=*/xegpu::MemorySpace::Global, layout);
7282
return xegpu::getDistributedVectorType(helperTdescTy);
7383
}

0 commit comments

Comments
 (0)