Skip to content

Commit 456b01c

Browse files
committed
refine
1 parent 1e353fa commit 456b01c

File tree

3 files changed

+51
-34
lines changed

3 files changed

+51
-34
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,28 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
253253
return $_get($_ctxt, sg_layout, sg_data, inst_data,
254254
DenseI32ArrayAttr::get($_ctxt, lane_layout),
255255
DenseI32ArrayAttr::get($_ctxt, lane_data), order);
256+
}]>,
257+
AttrBuilder<(ins "llvm::ArrayRef<int>": $lane_layout,
258+
"llvm::ArrayRef<int>": $lane_data,
259+
"llvm::ArrayRef<int>": $order),
260+
[{
261+
auto sg_layout = DenseI32ArrayAttr();
262+
auto sg_data = DenseI32ArrayAttr();
263+
auto inst_data = DenseI32ArrayAttr();
264+
return $_get($_ctxt, sg_layout, sg_data, inst_data,
265+
DenseI32ArrayAttr::get($_ctxt, lane_layout),
266+
DenseI32ArrayAttr::get($_ctxt, lane_data),
267+
DenseI32ArrayAttr::get($_ctxt, order));
268+
}]>,
269+
AttrBuilder<(ins "DenseI32ArrayAttr": $lane_layout,
270+
"DenseI32ArrayAttr": $lane_data,
271+
"DenseI32ArrayAttr": $order),
272+
[{
273+
auto sg_layout = DenseI32ArrayAttr();
274+
auto sg_data = DenseI32ArrayAttr();
275+
auto inst_data = DenseI32ArrayAttr();
276+
return $_get($_ctxt, sg_layout, sg_data, inst_data,
277+
lane_layout, lane_data, order);
256278
}]>
257279
];
258280

@@ -262,7 +284,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
262284
}
263285

264286
bool isSgLayout() {
265-
return getSgLayout() == nullptr && getLaneLayout() != nullptr;
287+
return !isWgLayout();
266288
}
267289

268290
int64_t getRank() {
@@ -274,6 +296,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
274296
return attr.size();
275297
return 0;
276298
}
299+
300+
LayoutAttr dropSgLayoutAndData() {
301+
return LayoutAttr::get(getContext(), nullptr, nullptr, getInstData(),
302+
getLaneLayout(), getLaneData(), getOrder());
303+
}
304+
305+
LayoutAttr dropInstData() {
306+
return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
307+
getLaneLayout(), getLaneData(), getOrder());
308+
}
309+
277310
}];
278311

279312
let assemblyFormat = "`<` struct(params) `>`";

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,12 +142,7 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
142142
OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source,
143143
"llvm::ArrayRef<OpFoldResult>": $offsets)>,
144144

145-
OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
146-
"llvm::ArrayRef<OpFoldResult>": $offsets,
147-
"llvm::ArrayRef<OpFoldResult>": $shape,
148-
"llvm::ArrayRef<OpFoldResult>": $strides)>,
149-
150-
OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
145+
OpBuilder<(ins "Type": $tdesc, "Value": $source,
151146
"llvm::ArrayRef<OpFoldResult>": $offsets,
152147
"llvm::ArrayRef<OpFoldResult>": $shape,
153148
"llvm::ArrayRef<OpFoldResult>": $strides)>

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -141,46 +141,24 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
141141
}
142142

143143
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
144-
Type tdesc, TypedValue<MemRefType> source,
144+
Type tdesc, Value source,
145145
llvm::ArrayRef<OpFoldResult> offsets,
146146
llvm::ArrayRef<OpFoldResult> shape,
147147
llvm::ArrayRef<OpFoldResult> strides) {
148148
assert(shape.size() && offsets.size() && strides.size() &&
149149
shape.size() == strides.size() && shape.size() == offsets.size());
150150

151-
llvm::SmallVector<int64_t> staticOffsets;
152-
llvm::SmallVector<int64_t> staticShape;
153-
llvm::SmallVector<int64_t> staticStrides;
151+
auto intTy = dyn_cast<IntegerType>(source.getType());
152+
auto memrefTy = dyn_cast<MemRefType>(source.getType());
153+
assert(intTy || memrefTy && "Source has to be either int or memref.");
154+
154155
llvm::SmallVector<Value> dynamicOffsets;
155156
llvm::SmallVector<Value> dynamicShape;
156157
llvm::SmallVector<Value> dynamicStrides;
157158

158-
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
159-
dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
160-
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
161-
162-
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
163-
auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
164-
auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
165-
166-
build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
167-
dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
168-
}
169-
170-
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
171-
Type tdesc, TypedValue<IntegerType> source,
172-
llvm::ArrayRef<OpFoldResult> offsets,
173-
llvm::ArrayRef<OpFoldResult> shape,
174-
llvm::ArrayRef<OpFoldResult> strides) {
175-
assert(shape.size() && offsets.size() && strides.size() &&
176-
shape.size() == strides.size() && shape.size() == offsets.size());
177-
178159
llvm::SmallVector<int64_t> staticOffsets;
179160
llvm::SmallVector<int64_t> staticShape;
180161
llvm::SmallVector<int64_t> staticStrides;
181-
llvm::SmallVector<Value> dynamicOffsets;
182-
llvm::SmallVector<Value> dynamicShape;
183-
llvm::SmallVector<Value> dynamicStrides;
184162

185163
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
186164
dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
@@ -190,6 +168,17 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
190168
auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
191169
auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
192170

171+
if (memrefTy) {
172+
auto memrefShape = memrefTy.getShape();
173+
auto [memrefStrides, offset] = memrefTy.getStridesAndOffset();
174+
175+
// if shape and strides are from Memref, we don't need attributes for them
176+
if (staticShape == memrefShape && staticStrides == memrefStrides) {
177+
staticShapeAttr = DenseI64ArrayAttr();
178+
staticStridesAttr = DenseI64ArrayAttr();
179+
}
180+
}
181+
193182
build(builder, state, tdesc, source, dynamicOffsets, dynamicShape,
194183
dynamicStrides, staticOffsetsAttr, staticShapeAttr, staticStridesAttr);
195184
}

0 commit comments

Comments
 (0)