Skip to content

Commit 2027cfc

Browse files
committed
add verifier and interface
1 parent 3959f9e commit 2027cfc

File tree

2 files changed

+74
-1
lines changed

2 files changed

+74
-1
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,17 @@ def LayoutTrait: AttrInterface<"LayoutTrait"> {
174174
let description = [{
175175
Common trait for all XeGPU layouts.
176176
}];
177+
178+
let methods = [
179+
InterfaceMethod<"Get the effective sg layout",
180+
"std::optional<llvm::SmallVector<int>>",
181+
"getEffectiveSgLayout">,
182+
InterfaceMethod<"Get the effective sg data",
183+
"std::optional<llvm::SmallVector<int>>",
184+
"getEffectiveSgData">,
185+
];
186+
187+
177188
}
178189

179190
def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
@@ -331,6 +342,18 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
331342
return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
332343
getLaneLayout(), getLaneData(), getOrder());
333344
}
345+
346+
std::optional<llvm::SmallVector<int32_t>> getEffectiveSgLayout() const {
347+
if (DenseI32ArrayAttr layout = getSgLayout())
348+
return llvm::to_vector(layout.asArrayRef());
349+
return std::nullopt;
350+
}
351+
352+
std::optional<llvm::SmallVector<int32_t>> getEffectiveSgData() const {
353+
if (DenseI32ArrayAttr data = getSgData())
354+
return llvm::to_vector(data.asArrayRef());
355+
return std::nullopt;
356+
}
334357
}];
335358

336359
let assemblyFormat = "`<` struct(params) `>`";
@@ -351,11 +374,40 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
351374
}];
352375

353376
let parameters = (ins
354-
"Attribute": $parent,
377+
"xegpu::LayoutAttr": $parent,
355378
"DenseI64ArrayAttr": $dims
356379
);
357380

381+
let extraClassDeclaration = [{
382+
std::optional<llvm::SmallVector<int32_t>> getEffectiveSgLayout() const {
383+
if (DenseI32ArrayAttr layout = getParent().getSgLayout()) {
384+
llvm::ArrayRef<int64_t> dims = getDims().asArrayRef();
385+
llvm::SmallVector<int32_t> result;
386+
for (auto [i, v]: llvm::enumerate(layout.asArrayRef())) {
387+
if (!llvm::is_contained(dims, i))
388+
result.push_back(v);
389+
}
390+
return result;
391+
}
392+
return std::nullopt;
393+
}
394+
std::optional<llvm::SmallVector<int32_t>> getEffectiveSgData() const {
395+
if (DenseI32ArrayAttr data = getParent().getSgData()) {
396+
llvm::ArrayRef<int64_t> dims = getDims().asArrayRef();
397+
llvm::SmallVector<int32_t> result;
398+
for (auto [i, v]: llvm::enumerate(data.asArrayRef())) {
399+
if (!llvm::is_contained(dims, i))
400+
result.push_back(v);
401+
}
402+
return result;
403+
}
404+
return std::nullopt;
405+
406+
}
407+
}];
408+
358409
let assemblyFormat = "`<` $parent `,` `dim` `=` $dims `>`";
410+
let genVerifyDecl = 1;
359411
}
360412

361413
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,27 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
206206
return success();
207207
}
208208

209+
//===----------------------------------------------------------------------===//
210+
// XeGPU_SliceAttr
211+
//===----------------------------------------------------------------------===//
212+
LogicalResult
213+
SliceAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
214+
xegpu::LayoutAttr parent, DenseI64ArrayAttr dims) {
215+
if (!parent || !dims)
216+
return emitError() << "expected parent layout and dims attribute";
217+
218+
int rank = parent.getRank();
219+
// check every element in dims is unique and smaller than rank
220+
llvm::SmallDenseSet<int64_t> seen;
221+
for (int64_t dim : dims.asArrayRef()) {
222+
if (dim >= rank)
223+
return emitError() << "invalid dim: " << dim;
224+
if (!seen.insert(dim).second)
225+
return emitError() << "repeated dim: " << dim;
226+
}
227+
return success();
228+
}
229+
209230
//===----------------------------------------------------------------------===//
210231
// XeGPU_TensorDescType
211232
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)