Skip to content

Commit 9ae490c

Browse files
committed
cleanup getNumSubgroups
1 parent ce07282 commit 9ae490c

File tree

1 file changed

+6
-15
lines changed

1 file changed

+6
-15
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,12 @@ def DistributeLayoutAttrInterface: AttrInterface<"DistributeLayoutAttrInterface"
190190
"getRank">,
191191
InterfaceMethod<"Get the num of effective subgroups",
192192
"int64_t",
193-
"getNumSubgroups">,
193+
"getNumSubgroups", (ins), [{
194+
std::optional<SmallVector<int64_t>> sgLayout = llvm::cast<ConcreteAttr>(tablegen_opaque_val).getSgLayoutAsInt();
195+
if (sgLayout.has_value())
196+
return computeProduct(*sgLayout);
197+
return 0;
198+
}], [{}]>,
194199
InterfaceMethod<"Get the SgLayout field of the attribute as integer array",
195200
"std::optional<SmallVector<int64_t>>",
196201
"getSgLayoutAsInt">,
@@ -355,13 +360,6 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttrInterf
355360
return 0;
356361
}
357362

358-
int64_t getNumSubgroups() {
359-
std::optional<SmallVector<int64_t>> sgLayout = getSgLayoutAsInt();
360-
if (sgLayout.has_value())
361-
return computeProduct(*sgLayout);
362-
return 0;
363-
}
364-
365363
LayoutAttr dropSgLayoutAndData() {
366364
// avoid every field of the attribute is nullptr, which may lead to segment fault
367365
if (!getInstData() && !getLaneLayout())
@@ -466,13 +464,6 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttrInterface
466464
return parent.isForSubgroup();
467465
}
468466

469-
int64_t getNumSubgroups() {
470-
std::optional<SmallVector<int64_t>> sgLayout = getSgLayoutAsInt();
471-
if (sgLayout.has_value())
472-
return computeProduct(*sgLayout);
473-
return 0;
474-
}
475-
476467
/// Returns the SgLayout of the attribute, computed by applying
477468
/// the slice dimensions to the underlying LayoutAttr.
478469
std::optional<SmallVector<int64_t>> getSgLayoutAsInt() const {

0 commit comments

Comments
 (0)