Skip to content

Commit e25937a

Browse files
committed
more nit comments
1 parent 86b0ee1 commit e25937a

File tree

2 files changed

+6
-12
lines changed

2 files changed

+6
-12
lines changed

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
10-
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1110
#include "mlir/Dialect/SCF/IR/SCF.h"
1211
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1312
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
@@ -34,7 +33,6 @@ class XeGPUTransformDialectExtension
3433
void XeGPUTransformDialectExtension::init() {
3534
declareGeneratedDialect<scf::SCFDialect>();
3635
declareGeneratedDialect<arith::ArithDialect>();
37-
declareGeneratedDialect<gpu::GPUDialect>();
3836
declareGeneratedDialect<xegpu::XeGPUDialect>();
3937

4038
registerTransformOps<
@@ -173,33 +171,29 @@ DiagnosedSilenceableFailure
173171
transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
174172
transform::TransformResults &results,
175173
transform::TransformState &state) {
176-
177174
auto targetOps = state.getPayloadOps(getTarget());
178175
if (!llvm::hasSingleElement(targetOps)) {
179176
return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
180177
<< llvm::range_size(targetOps) << ")";
181178
}
182179
Operation *target = *targetOps.begin();
183180

184-
auto transformOp = cast<TransformOpInterface>(getOperation());
185-
186181
SmallVector<int32_t> sgLayout;
187182
DiagnosedSilenceableFailure status =
188-
convertMixedValuesToInt(state, transformOp, sgLayout, getMixedSgLayout());
183+
convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout());
189184
if (!status.succeeded())
190185
return status;
191186

192187
SmallVector<int32_t> sgData;
193-
status =
194-
convertMixedValuesToInt(state, transformOp, sgData, getMixedSgData());
188+
status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData());
195189
if (!status.succeeded())
196190
return status;
197191
auto maybeSgData =
198192
sgData.empty() ? std::nullopt : std::optional<ArrayRef<int32_t>>(sgData);
199193

200194
SmallVector<int32_t> instData;
201195
status =
202-
convertMixedValuesToInt(state, transformOp, instData, getMixedInstData());
196+
convertMixedValuesToInt(state, (*this), instData, getMixedInstData());
203197
if (!status.succeeded())
204198
return status;
205199
auto maybeInstData = instData.empty()

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
except ImportError as e:
1818
raise RuntimeError("Error loading imports from extension module") from e
1919

20-
from typing import Union
20+
from typing import Union, Optional
2121

2222

2323
@_ods_cext.register_operation(_Dialect, replace=True)
@@ -29,8 +29,8 @@ def __init__(
2929
target: Union[Operation, Value],
3030
sg_layout: MixedValues,
3131
*,
32-
sg_data: MixedValues = None,
33-
inst_data: MixedValues = None,
32+
sg_data: Optional[MixedValues] = None,
33+
inst_data: Optional[MixedValues] = None,
3434
loc=None,
3535
ip=None,
3636
):

0 commit comments

Comments
 (0)