Skip to content

Commit 9dddbbc

Browse files
committed
address review comments
1 parent a038699 commit 9dddbbc

File tree

3 files changed

+19
-23
lines changed

3 files changed

+19
-23
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,16 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
3131
}];
3232

3333
let arguments = (ins
34-
TransformHandleTypeInterface : $target,
35-
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
36-
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
37-
Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
34+
TransformHandleTypeInterface:$target,
35+
Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_layout,
36+
Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_data,
37+
Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
3838
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
3939
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
4040
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
4141
);
4242

43-
let results = (outs TransformHandleTypeInterface : $transformed);
43+
let results = (outs TransformHandleTypeInterface:$transformed);
4444
let builders = [
4545
OpBuilder<(ins "Value":$target,
4646
"ArrayRef<OpFoldResult>":$mixedSgLayout,
@@ -92,11 +92,11 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
9292
is defined by the `sg_layout`, `sg_data` and optional `inst_data` attributes.
9393
}];
9494

95-
let arguments = (ins TransformHandleTypeInterface : $target,
96-
DefaultValuedOptionalAttr<I64Attr, "0"> : $index,
97-
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
98-
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
99-
Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
95+
let arguments = (ins TransformHandleTypeInterface:$target,
96+
DefaultValuedOptionalAttr<I64Attr, "0">:$index,
97+
Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_layout,
98+
Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_data,
99+
Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
100100
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
101101
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
102102
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,7 @@ createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
9292

9393
/// Generate `xegpu::LayoutAttr` from op mixed layout values.
9494
DiagnosedSilenceableFailure
95-
getLayoutAttrFromOperands(transform::TransformRewriter &rewriter,
96-
transform::TransformState &state,
95+
getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state,
9796
TransformOpInterface transformOp,
9897
ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
9998
ArrayRef<::mlir::OpFoldResult> mixedSgData,
@@ -116,8 +115,7 @@ getLayoutAttrFromOperands(transform::TransformRewriter &rewriter,
116115
? std::nullopt
117116
: std::optional<ArrayRef<int32_t>>(instData);
118117

119-
layoutAttr =
120-
createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
118+
layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData);
121119

122120
return DiagnosedSilenceableFailure::success();
123121
}
@@ -175,7 +173,7 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
175173
Operation *target = *targetOps.begin();
176174

177175
xegpu::LayoutAttr layoutAttr = nullptr;
178-
auto status = getLayoutAttrFromOperands(rewriter, state, (*this),
176+
auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
179177
getMixedSgLayout(), getMixedSgData(),
180178
getMixedInstData(), layoutAttr);
181179
if (!status.succeeded())
@@ -235,7 +233,6 @@ DiagnosedSilenceableFailure
235233
transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
236234
transform::TransformResults &results,
237235
transform::TransformState &state) {
238-
239236
auto targetOps = state.getPayloadOps(getTarget());
240237
if (!llvm::hasSingleElement(targetOps)) {
241238
return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
@@ -256,18 +253,17 @@ transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
256253
}
257254

258255
xegpu::LayoutAttr layoutAttr = nullptr;
259-
auto status = getLayoutAttrFromOperands(rewriter, state, (*this),
256+
auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
260257
getMixedSgLayout(), getMixedSgData(),
261258
getMixedInstData(), layoutAttr);
262259
if (!status.succeeded())
263260
return status;
264261

265262
// Set layout attribute for the op result or operand
266-
if (resultTarget) {
263+
if (resultTarget)
267264
xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr);
268-
} else {
265+
else
269266
xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layoutAttr);
270-
}
271267
return DiagnosedSilenceableFailure::success();
272268
}
273269

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,9 @@ def __init__(
7676
sg_layout: MixedValues,
7777
sg_data: MixedValues,
7878
*,
79-
inst_data: MixedValues = None,
80-
index: Union[int, Attribute] = None,
81-
result: Union[bool, Attribute] = None,
79+
inst_data: Optional[MixedValues] = None,
80+
index: Optional[Union[int, Attribute]] = None,
81+
result: Optional[Union[bool, Attribute]] = None,
8282
loc=None,
8383
ip=None,
8484
):

0 commit comments

Comments
 (0)