@@ -183,15 +183,20 @@ def LayoutTrait: AttrInterface<"LayoutTrait"> {
183183
184184 let methods = [
185185 InterfaceMethod<"Get the effective sg layout",
186- "std::optional<llvm:: SmallVector<int >>",
186+ "std::optional<SmallVector<int64_t >>",
187187 "getEffectiveSgLayout">,
188188 InterfaceMethod<"Get the effective sg data",
189- "std::optional<llvm:: SmallVector<int >>",
189+ "std::optional<SmallVector<int64_t >>",
190190 "getEffectiveSgData">,
191191 InterfaceMethod<"Delinearize the Subgroup Id",
192192 "FailureOr<SmallVector<Value>>",
193193 "delinearizeSubgroupId",
194- (ins "Value":$linearId, "Location":$loc, "OpBuilder &": $builder)>
194+ (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
195+
196+ InterfaceMethod<"Get the local offset to be accessed by the given subgroup Id",
197+ "FailureOr<SmallVector<SmallVector<Value>>>",
198+ "getOffsets",
199+ (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>
195200 ];
196201}
197202
@@ -351,20 +356,23 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
351356 getLaneLayout(), getLaneData(), getOrder());
352357 }
353358
354- std::optional<llvm:: SmallVector<int32_t >> getEffectiveSgLayout() const {
359+ std::optional<SmallVector<int64_t >> getEffectiveSgLayout() const {
355360 if (DenseI32ArrayAttr layout = getSgLayout())
356- return llvm::to_vector (layout.asArrayRef());
361+ return llvm::to_vector_of<int64_t> (layout.asArrayRef());
357362 return std::nullopt;
358363 }
359364
360- std::optional<llvm:: SmallVector<int32_t >> getEffectiveSgData() const {
365+ std::optional<SmallVector<int64_t >> getEffectiveSgData() const {
361366 if (DenseI32ArrayAttr data = getSgData())
362- return llvm::to_vector (data.asArrayRef());
367+ return llvm::to_vector_of<int64_t> (data.asArrayRef());
363368 return std::nullopt;
364369 }
365370
366371 FailureOr<SmallVector<Value>>
367- delinearizeSubgroupId(Value linearId, Location loc, OpBuilder &builder);
372+ delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
373+
374+ FailureOr<SmallVector<SmallVector<Value>>>
375+ getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
368376
369377 }];
370378
@@ -401,24 +409,6 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
401409 );
402410
403411 let extraClassDeclaration = [{
404- std::optional<llvm::SmallVector<int32_t>> getEffectiveSgLayout() const {
405- if (DenseI32ArrayAttr layout = getParent().getSgLayout()) {
406- llvm::ArrayRef<int64_t> dims = getDims().asArrayRef();
407- return XeGPUDialect::dropDims(layout.asArrayRef(), dims);
408- }
409- return std::nullopt;
410- }
411-
412- std::optional<llvm::SmallVector<int32_t>> getEffectiveSgData() const {
413- if (DenseI32ArrayAttr data = getParent().getSgData()) {
414- llvm::ArrayRef<int64_t> dims = getDims().asArrayRef();
415- return XeGPUDialect::dropDims(data.asArrayRef(), dims);
416- }
417- return std::nullopt;
418- }
419-
420- FailureOr<llvm::SmallVector<Value>>
421- delinearizeSubgroupId(Value linearId, Location loc, OpBuilder &builder);
422412
423413 DenseI32ArrayAttr getOrder() const {
424414 return getParent().getOrder();
@@ -431,6 +421,29 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
431421 bool isSgLayout() const {
432422 return getParent().isSgLayout();
433423 }
424+
425+ std::optional<SmallVector<int64_t>> getEffectiveSgLayout() const {
426+ if (auto layout = getParent().getEffectiveSgLayout()) {
427+ ArrayRef<int64_t> dims = getDims().asArrayRef();
428+ return XeGPUDialect::dropDims(llvm::ArrayRef<int64_t>(*layout), dims);
429+ }
430+ return std::nullopt;
431+ }
432+
433+ std::optional<SmallVector<int64_t>> getEffectiveSgData() const {
434+ if (auto data = getParent().getEffectiveSgData()) {
435+ ArrayRef<int64_t> dims = getDims().asArrayRef();
436+ return XeGPUDialect::dropDims(llvm::ArrayRef<int64_t>(*data), dims);
437+ }
438+ return std::nullopt;
439+ }
440+
441+ FailureOr<SmallVector<Value>>
442+ delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
443+
444+ FailureOr<SmallVector<SmallVector<Value>>>
445+ getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
446+
434447 }];
435448
436449 let assemblyFormat = "`<` $parent `,` `dims` `=` $dims `>`";
0 commit comments