Skip to content

Commit c42d7f7

Browse files
committed
address review comments
1 parent 95cadd2 commit c42d7f7

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,11 +167,13 @@ def SetGPULaunchThreadsOp
167167
TransformOpInterface
168168
]> {
169169

170-
let summary = "Set number of threads for a given gpu.launch operation";
171-
let description = "Set number of threads for a given `gpu.launch` operation.";
170+
let summary = "Set number of threads for a given gpu.launch operation"; let
171+
description = [{
172+
Overrides the x,y,z threads operands of a given `gpu.launch` operation in-place.
173+
}];
172174

173-
let arguments = (ins TransformHandleTypeInterface : $target,
174-
Variadic<TransformAnyParamTypeOrAnyHandle> : $threads,
175+
let arguments = (ins TransformHandleTypeInterface:$target,
176+
Variadic<TransformAnyParamTypeOrAnyHandle>:$threads,
175177
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_threads
176178
);
177179
let results = (outs);

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,6 @@ DiagnosedSilenceableFailure
358358
transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
359359
transform::TransformResults &results,
360360
transform::TransformState &state) {
361-
362361
auto targetOps = state.getPayloadOps(getTarget());
363362
if (!llvm::hasSingleElement(targetOps)) {
364363
return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
@@ -382,7 +381,8 @@ transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
382381

383382
if (threads.size() != 3) {
384383
return emitSilenceableFailure(getLoc())
385-
<< "Expected threads to be a 3D vector";
384+
<< "Expected threads argument to consist of three values (got "
385+
<< threads.size() << ")";
386386
}
387387

388388
rewriter.setInsertionPoint(launchOp);

mlir/test/python/dialects/transform_xegpu_ext.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def setGPULaunchThreadsOp():
120120
sequence = transform.SequenceOp(
121121
transform.FailurePropagationMode.Propagate,
122122
[],
123-
transform.OperationType.get("gpu.lauch"),
123+
transform.OperationType.get("gpu.launch"),
124124
)
125125
with InsertionPoint(sequence.body):
126126
xegpu.SetGPULaunchThreadsOp(sequence.bodyTarget, threads=[8, 4, 1])

0 commit comments

Comments
 (0)