Skip to content

Commit 9e799d6

Browse files
committed
Fix builders
1 parent e219003 commit 9e799d6

File tree

3 files changed

+22
-41
lines changed

3 files changed

+22
-41
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
142142
let builders = [
143143
OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source)>,
144144

145-
OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
146-
"llvm::ArrayRef<OpFoldResult>": $shape,
147-
"llvm::ArrayRef<OpFoldResult>": $strides)>,
148-
149-
OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
145+
OpBuilder<(ins "Type": $tdesc, "Value ": $source,
150146
"llvm::ArrayRef<OpFoldResult>": $shape,
151147
"llvm::ArrayRef<OpFoldResult>": $strides)>,
152148

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 19 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -156,48 +156,37 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
156156
}
157157

158158
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
159-
Type tdesc, TypedValue<MemRefType> source,
159+
Type tdesc, Value source,
160160
llvm::ArrayRef<OpFoldResult> shape,
161161
llvm::ArrayRef<OpFoldResult> strides) {
162-
assert(shape.size() && strides.size() && shape.size() == strides.size() &&
163-
"Shape and strides must be present and of equal size for ui64 "
164-
"initialization.");
162+
Type srcTy = source.getType();
163+
assert((isa<IntegerType, MemRefType>(srcTy)) &&
164+
"Source has to be either int or memref.");
165165

166-
llvm::SmallVector<int64_t> staticShape;
167-
llvm::SmallVector<int64_t> staticStrides;
168166
llvm::SmallVector<Value> dynamicShape;
169167
llvm::SmallVector<Value> dynamicStrides;
170168

171-
dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
172-
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
173-
174-
auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
175-
auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
176-
177-
build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
178-
dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
179-
staticStridesAttr);
180-
}
181-
182-
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
183-
Type tdesc, TypedValue<IntegerType> source,
184-
llvm::ArrayRef<OpFoldResult> shape,
185-
llvm::ArrayRef<OpFoldResult> strides) {
186-
assert(shape.size() && strides.size() && shape.size() == strides.size() &&
187-
"Shape and strides must be present and of equal size for ui64 "
188-
"initialization.");
189-
190169
llvm::SmallVector<int64_t> staticShape;
191170
llvm::SmallVector<int64_t> staticStrides;
192-
llvm::SmallVector<Value> dynamicShape;
193-
llvm::SmallVector<Value> dynamicStrides;
194171

195172
dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
196173
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
197174

198175
auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
199176
auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
200177

178+
if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
179+
auto memrefShape = memrefTy.getShape();
180+
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
181+
182+
// if shape and strides are from Memref, we don't need attributes for them
183+
// to keep the IR print clean.
184+
if (staticShape == memrefShape && staticStrides == memrefStrides) {
185+
staticShapeAttr = DenseI64ArrayAttr();
186+
staticStridesAttr = DenseI64ArrayAttr();
187+
}
188+
}
189+
201190
build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
202191
dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
203192
staticStridesAttr);
@@ -357,13 +346,10 @@ ParseResult parseOptionalDynamicIndexList(
357346
void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
358347
OperandRange values,
359348
DenseI64ArrayAttr integers) {
360-
361-
if (!integers)
349+
if (!integers || integers.empty())
362350
return;
363-
364-
return printDynamicIndexList(printer, op, values, integers,
365-
/*scalableFlags=*/{}, {},
366-
AsmParser::Delimiter::Square);
351+
printDynamicIndexList(printer, op, values, integers,
352+
/*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square);
367353
}
368354
//===----------------------------------------------------------------------===//
369355
// XeGPU_PrefetchNdOp

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,9 +240,8 @@ struct WgToSgCreateNdOpNoOffset
240240
SmallVector<Value> newCreateNdOps;
241241
for (int i = 0; i < count; ++i) {
242242
auto newOp = xegpu::CreateNdDescOp::create(
243-
rewriter, loc, newTdescTy, op.getSource(), ValueRange(), ValueRange(),
244-
ValueRange(), DenseI64ArrayAttr(), DenseI64ArrayAttr(),
245-
DenseI64ArrayAttr());
243+
rewriter, loc, newTdescTy, op.getSource(), op.getMixedSizes(),
244+
op.getMixedStrides());
246245
newCreateNdOps.push_back(newOp);
247246
}
248247
rewriter.replaceOpWithMultiple(op, {newCreateNdOps});

0 commit comments

Comments
 (0)