
Commit 5c1c908

save work
1 parent b3e6dc5 commit 5c1c908


5 files changed: +51 -39 lines changed


mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h

Lines changed: 8 additions & 0 deletions
@@ -34,4 +34,12 @@ class TensorDescType;
 #define GET_OP_CLASSES
 #include <mlir/Dialect/XeGPU/IR/XeGPU.h.inc>

+namespace mlir {
+namespace xegpu {
+FailureOr<VectorType> getDistributedVectorType(VectorType originalType,
+                                               LayoutAttr layout);
+FailureOr<VectorType> getDistributedVectorType(xegpu::TensorDescType tdescTy);
+} // namespace xegpu
+} // namespace mlir
+
 #endif // MLIR_DIALECT_XEGPU_IR_XEGPU_H
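A minimal caller's-eye sketch of the new entry point (not part of this commit; the wrapper name is hypothetical, and only the declarations above are assumed):

#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Query the per-lane (SIMT) fragment type for a tensor descriptor. Returns a
// null type when the descriptor carries no subgroup-level LayoutAttr, since
// getDistributedVectorType reports failure in that case.
static VectorType getFragmentTypeOrNull(xegpu::TensorDescType tdescTy) {
  FailureOr<VectorType> distTy = xegpu::getDistributedVectorType(tdescTy);
  if (failed(distTy))
    return {};
  return distTy.value();
}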

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 0 additions & 5 deletions
@@ -189,11 +189,6 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
         return scatter_attr.getChunkSize().getInt();
       return 1;
     }
-
-    // This returns a vector type that represents the fragment of data owned by
-    // a work item in SIMT mode if this tensor descriptor is used in a XeGPU
-    // load/store operation.
-    FailureOr<VectorType> getDistributedVectorType();
   }];

   let hasCustomAssemblyFormat = true;

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 21 additions & 7 deletions
@@ -8,6 +8,7 @@

 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include <numeric>
@@ -336,16 +337,17 @@ LogicalResult TensorDescType::verify(
 //    * tensor_desc[1] % (lane_layout[1] × lane_data[1]) == 0
 // Distributed vector is a 1D vector with shape:
 //    [fragment_size]
-FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
-  auto layout = llvm::dyn_cast_if_present<LayoutAttr>(getLayout());
+FailureOr<VectorType> getDistributedVectorType(xegpu::TensorDescType tdescTy) {
+  auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout());
   // It only works for subgroup level layout, which only has lane_layout
   // and lane_data, and is to distribute a SIMD code into SIMT code.
   if (!layout || !layout.isSgLayout())
     return failure();

   SmallVector<int64_t> laneData(layout.getLaneData().asArrayRef());
   SmallVector<int64_t> laneLayout(layout.getLaneLayout().asArrayRef());
-  auto tdescShape = getShape();
+  auto tdescShape = tdescTy.getShape();
+  auto elementType = tdescTy.getElementType();

   // compute sgSize by multiply elements of laneLayout
   // e.g. for 2D layout, sgSize = laneLayout[0] * laneLayout[1]
@@ -354,14 +356,14 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
                                 std::multiplies<int64_t>());

   // Case 1: regular loads/stores
-  auto scatterAttr = getEncodingAsScatterTensorDescAttr();
+  auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();
   if (scatterAttr) {
     auto chunkSize = scatterAttr.getChunkSize().getInt();
     // Verify if the first dimension of the tensor descriptor shape is
     // distributable.
     assert(tdescShape[0] == laneLayout[0] &&
            "tensor descriptor shape is not distributable");
-    return VectorType::get({chunkSize}, getElementType());
+    return VectorType::get({chunkSize}, elementType);
   }

   // Case 2: block loads/stores
@@ -374,9 +376,21 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
     tensorSize *= tdescDim;
   }
   // tensorSize must be adjusted for array_length.
-  tensorSize *= getArrayLength();
+  tensorSize *= tdescTy.getArrayLength();

-  return VectorType::get({tensorSize / sgSize}, getElementType());
+  return VectorType::get({tensorSize / sgSize}, elementType);
+}
+
+// Helper to get the distributed vector type for a given vector type according
+// to a given LayoutAttr.
+FailureOr<VectorType> getDistributedVectorType(VectorType originalType,
+                                               LayoutAttr layout) {
+  auto shape = originalType.getShape();
+  auto helperTdescTy = xegpu::TensorDescType::get(
+      shape, originalType.getElementType(),
+      /*array_length=*/1, /*boundary_check=*/true,
+      /*memory_space=*/xegpu::MemorySpace::Global, layout);
+  return xegpu::getDistributedVectorType(helperTdescTy);
 }

 } // namespace xegpu
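As a sanity check on the block (non-scatter) path above, consider a hypothetical 8x16 tensor_desc of f16 with array_length = 1 and a subgroup-level layout of lane_layout = [1, 16], lane_data = [1, 1]: sgSize = 1 * 16 = 16 and tensorSize = 8 * 16 * 1 = 128, so the function returns vector<8xf16> (128 / 16 elements per lane).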

mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -16,4 +16,5 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
   MLIRPass
   MLIRTransforms
   MLIRGPUDialect
+  MLIRXeGPUDialect
 )

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 21 additions & 27 deletions
@@ -696,9 +696,9 @@ class LayoutAttrAssignment {
   void assignToUsers(Value v, xegpu::LayoutAttr layout);
   xegpu::LayoutAttr getLayoutAttrForValue(Value v);
   LogicalResult resolveConflicts();
-  function_ref<LayoutInfo(Value)>
-      getAnalysisResult; // Callable to get the layout of a value based on the
-                         // layout propagation analysis.
+  // Callable to get the layout of a value based on the layout propagation
+  // analysis.
+  function_ref<LayoutInfo(Value)> getAnalysisResult;
   Operation *top;
 };

@@ -851,22 +851,6 @@ FailureOr<VectorType> getDistVecTypeBasedOnLaneLayout(xegpu::LayoutAttr layout,
   return VectorType::get(distributedShape, originalType.getElementType());
 }

-/// Get the distributed vector type for a source vector type according to a
-/// xegpu::LayoutAttr.
-static VectorType getDistributedVectorType(xegpu::LayoutAttr layout,
-                                           VectorType originalType) {
-  auto shape = originalType.getShape();
-  auto distVecTyOrFailure =
-      xegpu::TensorDescType::get(shape, originalType.getElementType(),
-                                 /*array_length=*/1, /*boundary_check=*/true,
-                                 /*memory_space=*/xegpu::MemorySpace::Global,
-                                 layout)
-          .getDistributedVectorType();
-  assert(llvm::succeeded(distVecTyOrFailure) &&
-         "Failed to compute distributed vector type for the given vector type");
-  return distVecTyOrFailure.value();
-}
-
 /// Drop the layout attribute from the tensor descriptor type if layout is
 /// present.
 static xegpu::TensorDescType dropLayouts(xegpu::TensorDescType tensorDesc) {
@@ -1175,7 +1159,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern {
     /// supported by the store op. Type mismatch must be resolved using
     /// appropriate cast op.
     auto storeNdDistributedValueTyOrFailure =
-        storeOp.getTensorDescType().getDistributedVectorType();
+        xegpu::getDistributedVectorType(storeOp.getTensorDescType());
     if (failed(storeNdDistributedValueTyOrFailure))
       return rewriter.notifyMatchFailure(
           storeOp, "Failed to get distributed vector type for the store op");
@@ -1263,7 +1247,7 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern {
    /// type.
    rewriter.setInsertionPointAfter(newWarpOp);
    auto loadNdDistValueTyOrFailure =
-       loadOp.getTensorDescType().getDistributedVectorType();
+       xegpu::getDistributedVectorType(loadOp.getTensorDescType());
    if (failed(loadNdDistValueTyOrFailure))
      return rewriter.notifyMatchFailure(
          loadOp, "Failed to get distributed vector type for the load op");
@@ -1379,17 +1363,27 @@ struct DpasDistribution final : public gpu::WarpDistributionPattern {
    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, subgroupOp, newYieldValues, newYieldTypes, newRetIndices);

+   FailureOr<VectorType> expectedDistLhsTyOrFailure =
+       xegpu::getDistributedVectorType(dpasOp.getLhsType(), layoutA);
+   FailureOr<VectorType> expectedDistRhsTyOrFailure =
+       xegpu::getDistributedVectorType(dpasOp.getRhsType(), layoutB);
+   FailureOr<VectorType> expectedDistResultTyOrFailure =
+       xegpu::getDistributedVectorType(dpasOp.getResultType(), layoutOut);
+   if (failed(expectedDistLhsTyOrFailure) ||
+       failed(expectedDistRhsTyOrFailure) ||
+       failed(expectedDistResultTyOrFailure))
+     return rewriter.notifyMatchFailure(
+         dpasOp,
+         "Failed to get distributed vector type for the dpas operands.");
    // Create a new dpas op outside the warp op.
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newDpasOperands;
    SmallVector<VectorType> newDpasOperandExpectedTypes;
+
    /// Resolve the distributed types with the original types.
-   newDpasOperandExpectedTypes.push_back(
-       getDistributedVectorType(layoutA, dpasOp.getLhsType()));
-   newDpasOperandExpectedTypes.push_back(
-       getDistributedVectorType(layoutB, dpasOp.getRhsType()));
-   auto distributedResultTy =
-       getDistributedVectorType(layoutOut, dpasOp.getResultType());
+   newDpasOperandExpectedTypes.push_back(expectedDistLhsTyOrFailure.value());
+   newDpasOperandExpectedTypes.push_back(expectedDistRhsTyOrFailure.value());
+   auto distributedResultTy = expectedDistResultTyOrFailure.value();
    if (dpasOp.getAcc())
      newDpasOperandExpectedTypes.push_back(distributedResultTy);
