Skip to content

Commit 4c08b46

Browse files
committed
address review comments
1 parent 5fb499a commit 4c08b46

File tree

3 files changed

+14
-12
lines changed

3 files changed

+14
-12
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,16 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
3131
}];
3232

3333
let arguments = (ins
34-
TransformHandleTypeInterface : $target,
35-
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
36-
Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
37-
Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
34+
TransformHandleTypeInterface:$target,
35+
Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_layout,
36+
Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_data,
37+
Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
3838
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
3939
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
4040
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
4141
);
4242

43-
let results = (outs TransformHandleTypeInterface : $transformed);
43+
let results = (outs TransformHandleTypeInterface:$transformed);
4444
let builders = [
4545
OpBuilder<(ins "Value":$target,
4646
"ArrayRef<OpFoldResult>":$mixedSgLayout,
@@ -84,11 +84,13 @@ def SetGPULaunchThreadsOp
8484
TransformOpInterface
8585
]> {
8686

87-
let summary = "Set number of threads for a given gpu.launch operation";
88-
let description = "Set number of threads for a given `gpu.launch` operation.";
87+
let summary = "Set number of threads for a given gpu.launch operation"; let
88+
description = [{
89+
Overrides the x,y,z threads operands of a given `gpu.launch` operation in-place.
90+
}];
8991

90-
let arguments = (ins TransformHandleTypeInterface : $target,
91-
Variadic<TransformAnyParamTypeOrAnyHandle> : $threads,
92+
let arguments = (ins TransformHandleTypeInterface:$target,
93+
Variadic<TransformAnyParamTypeOrAnyHandle>:$threads,
9294
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_threads
9395
);
9496
let results = (outs);

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,6 @@ DiagnosedSilenceableFailure
210210
transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
211211
transform::TransformResults &results,
212212
transform::TransformState &state) {
213-
214213
auto targetOps = state.getPayloadOps(getTarget());
215214
if (!llvm::hasSingleElement(targetOps)) {
216215
return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
@@ -234,7 +233,8 @@ transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
234233

235234
if (threads.size() != 3) {
236235
return emitSilenceableFailure(getLoc())
237-
<< "Expected threads to be a 3D vector";
236+
<< "Expected threads argument to consist of three values (got "
237+
<< threads.size() << ")";
238238
}
239239

240240
rewriter.setInsertionPoint(launchOp);

mlir/test/python/dialects/transform_xegpu_ext.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def setGPULaunchThreadsOp():
5656
sequence = transform.SequenceOp(
5757
transform.FailurePropagationMode.Propagate,
5858
[],
59-
transform.OperationType.get("gpu.lauch"),
59+
transform.OperationType.get("gpu.launch"),
6060
)
6161
with InsertionPoint(sequence.body):
6262
xegpu.SetGPULaunchThreadsOp(sequence.bodyTarget, threads=[8, 4, 1])

0 commit comments

Comments
 (0)