
Commit a84014f

format
1 parent 0e34f36 commit a84014f

5 files changed: +33 -23 lines changed


mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h

Lines changed: 13 additions & 12 deletions
@@ -67,20 +67,21 @@ std::string getLayoutName(const OpOperand &operand);
 /// Return the attribute name for the OpResult to attach DistributeLayoutAttr
 std::string getLayoutName(const OpResult result);
 
-/// Retrieves the DistributeLayoutAttr associated with a given Value. For TensorDescType
-/// values, the DistributeLayoutAttr is extracted from the TensorDescType itself. For
-/// other values, it is obtained from the attributes of the defining operation.
-/// Returns nullptr if no DistributeLayoutAttr is found.
+/// Retrieves the DistributeLayoutAttr associated with a given Value. For
+/// TensorDescType values, the DistributeLayoutAttr is extracted from the
+/// TensorDescType itself. For other values, it is obtained from the attributes
+/// of the defining operation. Returns nullptr if no DistributeLayoutAttr is
+/// found.
 DistributeLayoutAttr getDistributeLayoutAttr(const Value value);
 
 template <typename AttrTy>
 AttrTy getDistributeLayoutAttrOfType(const Value value) {
   return dyn_cast_if_present<AttrTy>(getDistributeLayoutAttr(value));
 }
 
-/// Retrieves the DistributeLayoutAttr associated with a given OpOperand. It will
-/// first check the operand_layout_{id} of the owner operation. If not found,
-/// it will check the operand itself and its defining op.
+/// Retrieves the DistributeLayoutAttr associated with a given OpOperand. It
+/// will first check the operand_layout_{id} of the owner operation. If not
+/// found, it will check the operand itself and its defining op.
 DistributeLayoutAttr getDistributeLayoutAttr(const OpOperand &opr);
 
 template <typename AttrTy>
@@ -94,8 +95,8 @@ template <typename T,
                                       std::is_same_v<T, OpResult>>>
 void removeLayoutAttr(const T &operandOrResult);
 
-/// Removes the DistributeLayoutAttr for each OpOperand and OpResult of the given
-/// operation if they exist. If the operation contains regions, it is also
+/// Removes the DistributeLayoutAttr for each OpOperand and OpResult of the
+/// given operation if they exist. If the operation contains regions, it is also
 /// applied recursively to the contained operations
 void removeLayoutAttrs(Operation *op);
 
@@ -107,9 +108,9 @@ template <typename T,
 void setDistributeLayoutAttr(const T &operandOrResult,
                              const DistributeLayoutAttr layout);
 
-/// Set the DistributeLayoutAttr for each OpOperand and OpResult of the given operation.
-/// If the operation contains regions, it is also applied recursively to the
-/// contained operations
+/// Set the DistributeLayoutAttr for each OpOperand and OpResult of the given
+/// operation. If the operation contains regions, it is also applied recursively
+/// to the contained operations
 void setDistributeLayoutAttrs(
     Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl);
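
The re-wrapped doc comments above spell out the lookup order these helpers use. As a reading aid only, here is a minimal sketch (not part of this commit) of how the declared utilities compose; the walk and the helper name inspectLayouts are hypothetical, while the xegpu::* calls are the declarations shown in this header.

// Minimal usage sketch; `inspectLayouts` and the walk are hypothetical,
// the xegpu::* helpers are the declarations shown above.
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/IR/Operation.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

static void inspectLayouts(Operation *root) {
  root->walk([](Operation *op) {
    for (OpOperand &opr : op->getOpOperands()) {
      // Consults the owner's operand_layout_{id} attribute first, then the
      // operand value and its defining op.
      if (xegpu::DistributeLayoutAttr layout =
              xegpu::getDistributeLayoutAttr(opr))
        llvm::outs() << "operand layout: " << layout << "\n";
    }
    for (OpResult res : op->getOpResults()) {
      // For TensorDescType values the layout comes from the type itself,
      // otherwise from the defining op's attributes; may be null.
      if (xegpu::DistributeLayoutAttr layout =
              xegpu::getDistributeLayoutAttr(res))
        llvm::outs() << "result layout: " << layout << "\n";
    }
  });
  // Strip the layout attributes from `root` and, recursively, from the ops
  // in its nested regions once they are no longer needed.
  xegpu::removeLayoutAttrs(root);
}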

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 2 additions & 2 deletions
@@ -147,8 +147,8 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
   auto instShape = maybeInstShape.value();
 
   // check LaneLayout and LaneData
-  auto maybeLaneShape =
-      tryDistribute(instShape, attr.getLaneLayoutAsInt(), attr.getLaneDataAsInt(), false);
+  auto maybeLaneShape = tryDistribute(instShape, attr.getLaneLayoutAsInt(),
+                                      attr.getLaneDataAsInt(), false);
   return maybeLaneShape.has_value();
 }

mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp

Lines changed: 6 additions & 3 deletions
@@ -140,7 +140,8 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
   else
     value = (Value)operandOrResult;
 
-  xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(operandOrResult);
+  xegpu::DistributeLayoutAttr layout =
+      xegpu::getDistributeLayoutAttr(operandOrResult);
   if (layout && layout.isForSubgroup()) {
     if (auto inst_data = layout.getInstDataAsInt())
       return inst_data.value();
@@ -204,12 +205,14 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
   // skip the op if any of its operands or results has workgroup level layouts
   bool hasWgLayoutOperands =
       llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
-        xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(opr);
+        xegpu::DistributeLayoutAttr layout =
+            xegpu::getDistributeLayoutAttr(opr);
         return layout && layout.isForWorkgroup();
       });
   bool hasWgLayoutResults =
       llvm::any_of(op->getOpResults(), [](OpResult result) {
-        xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(result);
+        xegpu::DistributeLayoutAttr layout =
+            xegpu::getDistributeLayoutAttr(result);
         return layout && layout.isForWorkgroup();
       });
   if (hasWgLayoutOperands || hasWgLayoutResults) {

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 8 additions & 4 deletions
@@ -470,7 +470,8 @@ struct WgToSgVectorBroadcastOp
     VectorType resultType = op.getResult().getType();
     ArrayRef<int64_t> wgShape = resultType.getShape();
 
-    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(op.getResult());
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
@@ -535,7 +536,8 @@ struct WgToSgElementwiseOp : public ConversionPattern {
 
     ArrayRef<int64_t> wgShape = resultType.getShape();
 
-    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op->getResult(0));
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
@@ -737,7 +739,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
     if (!vecAttr || !vecAttr.isSplat() || !vecType)
       return failure();
 
-    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(op.getResult());
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
@@ -980,7 +983,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
       }
     }
 
-    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op->getResult(0));
     return isLegal(layout);
   });

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 4 additions & 2 deletions
@@ -151,7 +151,8 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
   return nullptr;
 }
 
-xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
+xegpu::DistributeLayoutAttr
+xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
   Operation *op = opr.getOwner();
   std::string layoutName = xegpu::getLayoutName(opr);
   if (op->hasAttr(layoutName))
@@ -307,7 +308,8 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
     if (!inputTy || !resultTy)
      return WalkResult::skip();
 
-    xegpu::DistributeLayoutAttr layout = xegpu::getDistributeLayoutAttr(input);
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(input);
     if (!layout)
       return WalkResult::skip();

0 commit comments
