@@ -174,6 +174,17 @@ def LayoutTrait: AttrInterface<"LayoutTrait"> {
174174 let description = [{
175175 Common trait for all XeGPU layouts.
176176 }];
177+
178+ let methods = [
179+ InterfaceMethod<"Get the effective sg layout",
180+ "std::optional<llvm::SmallVector<int>>",
181+ "getEffectiveSgLayout">,
182+ InterfaceMethod<"Get the effective sg data",
183+ "std::optional<llvm::SmallVector<int>>",
184+ "getEffectiveSgData">,
185+ ];
186+
187+
177188}
178189
179190def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
@@ -331,6 +342,18 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
331342 return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
332343 getLaneLayout(), getLaneData(), getOrder());
333344 }
345+
346+ std::optional<llvm::SmallVector<int32_t>> getEffectiveSgLayout() const {
347+ if (DenseI32ArrayAttr layout = getSgLayout())
348+ return llvm::to_vector(layout.asArrayRef());
349+ return std::nullopt;
350+ }
351+
352+ std::optional<llvm::SmallVector<int32_t>> getEffectiveSgData() const {
353+ if (DenseI32ArrayAttr data = getSgData())
354+ return llvm::to_vector(data.asArrayRef());
355+ return std::nullopt;
356+ }
334357 }];
335358
336359 let assemblyFormat = "`<` struct(params) `>`";
@@ -351,11 +374,40 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
351374 }];
352375
353376 let parameters = (ins
354- "Attribute ": $parent,
377+ "xegpu::LayoutAttr ": $parent,
355378 "DenseI64ArrayAttr": $dims
356379 );
357380
381+ let extraClassDeclaration = [{
382+ std::optional<llvm::SmallVector<int32_t>> getEffectiveSgLayout() const {
383+ if (DenseI32ArrayAttr layout = getParent().getSgLayout()) {
384+ llvm::ArrayRef<int64_t> dims = getDims().asArrayRef();
385+ llvm::SmallVector<int32_t> result;
386+ for (auto [i, v]: llvm::enumerate(layout.asArrayRef())) {
387+ if (!llvm::is_contained(dims, i))
388+ result.push_back(v);
389+ }
390+ return result;
391+ }
392+ return std::nullopt;
393+ }
394+ std::optional<llvm::SmallVector<int32_t>> getEffectiveSgData() const {
395+ if (DenseI32ArrayAttr data = getParent().getSgData()) {
396+ llvm::ArrayRef<int64_t> dims = getDims().asArrayRef();
397+ llvm::SmallVector<int32_t> result;
398+ for (auto [i, v]: llvm::enumerate(data.asArrayRef())) {
399+ if (!llvm::is_contained(dims, i))
400+ result.push_back(v);
401+ }
402+ return result;
403+ }
404+ return std::nullopt;
405+
406+ }
407+ }];
408+
358409 let assemblyFormat = "`<` $parent `,` `dim` `=` $dims `>`";
410+ let genVerifyDecl = 1;
359411}
360412
361413#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
0 commit comments