Skip to content

Commit d7e0f31

Browse files
committed
[Intel] Use 'CTAEncodingAttr' after '49b7472'
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent b82b6a8 commit d7e0f31

File tree

13 files changed

+86
-116
lines changed

13 files changed

+86
-116
lines changed

python/test/unit/intel/test_block_io.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,14 @@ def __str__(self):
5353

5454
class BlockedLayout:
5555

56-
def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga=[1, 1],
57-
cta_split_num=[1, 1], cta_order=[0, 1]):
56+
def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
5857
self.sz_per_thread = size_per_thread
5958
self.threads_per_warp = threads_per_warp
6059
self.warps_per_cta = warps_per_cta
6160
self.order = order
62-
self.ctas_per_cga = ctas_per_cga
63-
self.cta_split_num = cta_split_num
64-
self.cta_order = cta_order
6561

6662
def __str__(self):
67-
return f"#ttg.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
63+
return f"#ttg.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
6864

6965

7066
def warps_per_cta(layout):
@@ -75,7 +71,7 @@ def warps_per_cta(layout):
7571

7672

7773
layouts = [
78-
BlockedLayout([1, 1], [2, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
74+
BlockedLayout([1, 1], [2, 16], [4, 1], [1, 0]),
7975
# DPAS layout
8076
DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=4, threads_per_warp=16,
8177
warps_per_cta=[1, 4], rep_cluster=[1, 2]),
@@ -110,8 +106,7 @@ def warps_per_cta(layout):
110106
parent=DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=1, threads_per_warp=32,
111107
warps_per_cta=[2, 2], rep_cluster=[1, 1]), op_idx=1, k_width=1),
112108
# Slice layout
113-
SliceLayout(dim=1, parent=BlockedLayout([1, 4, 1], [2, 1, 16], [2, 1, 2], [2, 1, 0], [1, 1, 1], [1, 1, 1],
114-
[0, 1, 2])),
109+
SliceLayout(dim=1, parent=BlockedLayout([1, 4, 1], [2, 1, 16], [2, 1, 2], [2, 1, 0])),
115110
]
116111

117112

python/test/unit/intel/test_core.py

Lines changed: 44 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -62,33 +62,26 @@ def __str__(self):
6262

6363
class BlockedLayout:
6464

65-
def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga=[1, 1],
66-
cta_split_num=[1, 1], cta_order=[0, 1]):
65+
def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
6766
self.sz_per_thread = size_per_thread
6867
self.threads_per_warp = threads_per_warp
6968
self.warps_per_cta = warps_per_cta
7069
self.order = order
71-
self.ctas_per_cga = ctas_per_cga
72-
self.cta_split_num = cta_split_num
73-
self.cta_order = cta_order
7470

7571
def __str__(self):
76-
return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
72+
return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
7773

7874

7975
class SwizzledSharedLayout:
8076

81-
def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order):
77+
def __init__(self, vec, per_phase, max_phase, order):
8278
self.vec = vec
8379
self.per_phase = per_phase
8480
self.max_phase = max_phase
8581
self.order = order
86-
self.ctas_per_cga = ctas_per_cga
87-
self.cta_split_num = cta_split_num
88-
self.cta_order = cta_order
8982

9083
def __str__(self):
91-
return f"#{GPU_DIALECT}.swizzled_shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
84+
return f"#{GPU_DIALECT}.swizzled_shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}}}>"
9285

9386

9487
class PaddedSharedLayout:
@@ -172,17 +165,17 @@ def get_reduce_input(dtype_str, shape):
172165

173166

174167
scan_layouts = [
175-
BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
176-
BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
177-
BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]),
178-
BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
179-
BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
180-
BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
181-
BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
182-
BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]),
183-
BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
184-
BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
185-
BlockedLayout([1, 2], [1, THREADS_PER_WARP // 1], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]),
168+
BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1]),
169+
BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1]),
170+
BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1]),
171+
BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1]),
172+
BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1]),
173+
BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0]),
174+
BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0]),
175+
BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0]),
176+
BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0]),
177+
BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0]),
178+
BlockedLayout([1, 2], [1, THREADS_PER_WARP // 1], [1, 4], [1, 0]),
186179
]
187180

188181

@@ -254,8 +247,8 @@ def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_pa
254247

255248

256249
layouts = [
257-
BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
258-
BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
250+
BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0]),
251+
BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1]),
259252
DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,
260253
warps_per_cta=[4, 1], rep_cluster=[1, 1]),
261254
DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=32,
@@ -305,8 +298,8 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov
305298
store_range = "%7" if axis == 0 else "%1"
306299
warps = warps_per_cta(src_layout, [M, N])
307300
num_warps = int(np.prod(warps))
308-
blocked = BlockedLayout([1, 1], [32, THREADS_PER_WARP // 32], [4, num_warps // 4], [0, 1], [1, 1], [1, 1], [0, 1])
309-
one_d_layout = BlockedLayout([1], [THREADS_PER_WARP], [num_warps], [0], [1], [1], [0])
301+
blocked = BlockedLayout([1, 1], [32, THREADS_PER_WARP // 32], [4, num_warps // 4], [0, 1])
302+
one_d_layout = BlockedLayout([1], [THREADS_PER_WARP], [num_warps], [0])
310303

311304
expanded_shape = f"1x{N}" if axis == 0 else f"{M}x1"
312305
other_axis = 1 - axis
@@ -397,8 +390,8 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov
397390

398391

399392
layouts = [
400-
BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
401-
BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
393+
BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0]),
394+
BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0]),
402395
DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,
403396
warps_per_cta=[4, 1], rep_cluster=[1, 1]),
404397
]
@@ -443,8 +436,8 @@ def test_store_op(M, src_layout, device, tmp_path: pathlib.Path):
443436

444437

445438
layouts = [
446-
BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
447-
BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
439+
BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0]),
440+
BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0]),
448441
DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,
449442
warps_per_cta=[4, 1], rep_cluster=[1, 1])
450443
]
@@ -532,10 +525,10 @@ def test_convert1d_bool(M, src_layout, dst_layout, src_dim, dst_dim, device, tmp
532525

533526

534527
layouts = [
535-
BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
536-
BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
537-
BlockedLayout([1, 4], [THREADS_PER_WARP // 32, 32], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]),
538-
BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1])
528+
BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0]),
529+
BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0]),
530+
BlockedLayout([1, 4], [THREADS_PER_WARP // 32, 32], [1, 4], [1, 0]),
531+
BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1])
539532
]
540533

541534

@@ -611,8 +604,8 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path: pathli
611604
# TODO: backend should be tested separately
612605

613606
layouts = [
614-
BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
615-
BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
607+
BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1]),
608+
BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0]),
616609
DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,
617610
warps_per_cta=[4, 1], rep_cluster=[1, 1]),
618611
DpasLayout(repeatCount=2, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,
@@ -621,10 +614,10 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path: pathli
621614

622615
intermediate_layouts = [
623616
None,
624-
SwizzledSharedLayout(1, 1, 1, [0, 1], [1, 1], [1, 1], [0, 1]),
625-
SwizzledSharedLayout(1, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]),
626-
SwizzledSharedLayout(4, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]),
627-
SwizzledSharedLayout(2, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]),
617+
SwizzledSharedLayout(1, 1, 1, [0, 1]),
618+
SwizzledSharedLayout(1, 1, 1, [1, 0]),
619+
SwizzledSharedLayout(4, 2, 4, [1, 0]),
620+
SwizzledSharedLayout(2, 2, 4, [1, 0]),
628621
]
629622

630623

@@ -736,15 +729,15 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t
736729

737730

738731
layouts_3d = [
739-
BlockedLayout([4, 4, 1], [1, 8, THREADS_PER_WARP // 8], [2, 2, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
740-
BlockedLayout([1, 1, 4], [8, THREADS_PER_WARP // 8, 1], [2, 1, 2], [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
732+
BlockedLayout([4, 4, 1], [1, 8, THREADS_PER_WARP // 8], [2, 2, 1], [2, 1, 0]),
733+
BlockedLayout([1, 1, 4], [8, THREADS_PER_WARP // 8, 1], [2, 1, 2], [1, 2, 0]),
741734
]
742735

743736
shared_layouts_3d = [
744-
SwizzledSharedLayout(1, 1, 1, [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
745-
SwizzledSharedLayout(4, 2, 4, [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
746-
SwizzledSharedLayout(8, 2, 4, [0, 2, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
747-
SwizzledSharedLayout(4, 2, 1, [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
737+
SwizzledSharedLayout(1, 1, 1, [2, 1, 0]),
738+
SwizzledSharedLayout(4, 2, 4, [1, 2, 0]),
739+
SwizzledSharedLayout(8, 2, 4, [0, 2, 1]),
740+
SwizzledSharedLayout(4, 2, 1, [2, 0, 1]),
748741
]
749742

750743

@@ -841,9 +834,9 @@ def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path:
841834
]
842835

843836
shared_layouts = [
844-
SwizzledSharedLayout(4, 2, 4, [0, 1], [1, 1], [1, 1], [0, 1]),
845-
SwizzledSharedLayout(8, 1, 8, [1, 0], [1, 1], [1, 1], [0, 1]),
846-
SwizzledSharedLayout(16, 1, 16, [1, 0], [1, 1], [1, 1], [0, 1]),
837+
SwizzledSharedLayout(4, 2, 4, [0, 1]),
838+
SwizzledSharedLayout(8, 1, 8, [1, 0]),
839+
SwizzledSharedLayout(16, 1, 16, [1, 0]),
847840
]
848841

849842

@@ -855,7 +848,7 @@ def test_split_subview(M, N, M_tile_size, N_tile_size, device, tmp_path: pathlib
855848
num_repeats_N = triton.cdiv(N, N_tile_size)
856849

857850
ir = f"""
858-
#blocked = #ttg.blocked<{{sizePerThread=[1, 8], threadsPerWarp=[{num_rows_per_warp}, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}}>
851+
#blocked = #ttg.blocked<{{sizePerThread=[1, 8], threadsPerWarp=[{num_rows_per_warp}, 4], warpsPerCTA=[4, 1], order=[1, 0]}}>
859852
#shared = #ttg.swizzled_shared<{{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}}>
860853
#smem = #ttg.shared_memory
861854
@@ -989,7 +982,7 @@ def test_local_load_store_dot(M, N, dtype, dist_layout, shared_layout, device, t
989982
]
990983

991984
shared_layouts = [
992-
SwizzledSharedLayout(8, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]),
985+
SwizzledSharedLayout(8, 1, 1, [1, 0]),
993986
]
994987

995988

python/triton/experimental/gluon/language/intel/_layouts.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,7 @@ def __hash__(self):
8484
self.threads_per_warp,
8585
tuple(self.cta_order),
8686
))
87+
88+
@property
89+
def rank(self):
90+
return len(self.warps_per_cta)

test/Conversion/intel/tritongpu_to_gen.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1389,7 +1389,7 @@ tt.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
13891389

13901390
// -----
13911391

1392-
#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [4], CTASplitNum = [1], CTAOrder = [0]}>
1392+
#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CGALayout = [[0], [0]]}>
13931393
module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32} {
13941394
// CHECK-LABEL: test_get_program_id
13951395
tt.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {

test/TritonIntelGPU/tritongpu_reduce_op_lowering.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
// COM: Tests reduction when threads_per_warp < num_warps.
44

5-
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [64], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
5+
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [64], order = [0]}>
66
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} {
77
// CHECK-LABEL: reduce_problem_size_64_threads_per_warp_32
88
tt.func @reduce_problem_size_64_threads_per_warp_32(%f : tensor<2048xi32, #blocked>) {

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def Subgroup2DBlockEncodingAttr : DistributedEncoding<"Subgroup2DBlockEncoding",
310310
let parameters = (
311311
ins
312312
ArrayRefParameter<"unsigned">:$warpsPerCTA,
313-
"CTALayoutAttr":$CTALayout,
313+
"CTAEncodingAttr":$CTALayout,
314314
ArrayRefParameter<"unsigned">:$instrShape,
315315
"unsigned":$numBlocks,
316316
ArrayRefParameter<"unsigned">:$order,

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 20 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -146,22 +146,12 @@ DpasEncodingAttr::getRepOrderForOperand(OpIdx opIdx) const {
146146
return getOrderForDotOperand(unsigned(opIdx), rank, /*kMajor*/ true);
147147
}
148148

149-
SmallVector<unsigned> DpasEncodingAttr::getCTASplitNum() const {
149+
CTAEncodingAttr DpasEncodingAttr::getCTALayout() const {
150150
size_t rank = getWarpsPerCTA().size();
151-
SmallVector<unsigned> res(rank, 1);
152-
return res;
153-
}
154-
155-
SmallVector<unsigned> DpasEncodingAttr::getCTAOrder() const {
156-
size_t rank = getWarpsPerCTA().size();
157-
auto res = llvm::to_vector(llvm::reverse(llvm::seq<unsigned>(rank)));
158-
return res;
159-
}
160-
161-
SmallVector<unsigned> DpasEncodingAttr::getCTAsPerCGA() const {
162-
size_t rank = getWarpsPerCTA().size();
163-
SmallVector<unsigned> res(rank, 1);
164-
return res;
151+
SmallVector<unsigned> CTAsPerCGA(rank, 1);
152+
auto CTAOrder = llvm::to_vector(llvm::reverse(llvm::seq<unsigned>(rank)));
153+
return CTAEncodingAttr::fromSplitParams(getContext(), CTAsPerCGA, CTAsPerCGA,
154+
CTAOrder);
165155
}
166156

167157
SmallVector<int64_t>
@@ -441,16 +431,8 @@ LinearLayout WarpEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
441431
llvm::report_fatal_error("NYI. WarpEncodingAttr::toLinearLayout");
442432
}
443433

444-
SmallVector<unsigned> WarpEncodingAttr::getCTAsPerCGA() const {
445-
llvm::report_fatal_error("NYI. WarpEncodingAttr::getCTAsPerCGA");
446-
}
447-
448-
SmallVector<unsigned> WarpEncodingAttr::getCTAOrder() const {
449-
llvm::report_fatal_error("NYI. WarpEncodingAttr::getCTAOrder");
450-
}
451-
452-
SmallVector<unsigned> WarpEncodingAttr::getCTASplitNum() const {
453-
llvm::report_fatal_error("NYI. WarpEncodingAttr::getCTASplitNum");
434+
CTAEncodingAttr WarpEncodingAttr::getCTALayout() const {
435+
llvm::report_fatal_error("NYI. WarpEncodingAttr::getCTALayout");
454436
}
455437

456438
Attribute WarpEncodingAttr::parse(AsmParser &parser, Type type) {
@@ -506,16 +488,16 @@ void WarpEncodingAttr::print(mlir::AsmPrinter &printer) const {
506488
//===----------------------------------------------------------------------===//
507489

508490
namespace {
509-
std::optional<CTALayoutAttr> getCTALayoutOrError(
491+
std::optional<CTAEncodingAttr> getCTALayoutOrError(
510492
AsmParser &parser, std::optional<SmallVector<unsigned>> CTAsPerCGA,
511493
std::optional<SmallVector<unsigned>> CTASplitNum,
512494
std::optional<SmallVector<unsigned>> CTAOrder, unsigned rank) {
513495
if (CTAsPerCGA && CTASplitNum && CTAOrder) {
514-
return CTALayoutAttr::get(parser.getContext(), *CTAsPerCGA, *CTASplitNum,
515-
*CTAOrder);
496+
return CTAEncodingAttr::fromSplitParams(parser.getContext(), *CTAsPerCGA,
497+
*CTASplitNum, *CTAOrder);
516498
}
517499
if (!CTAsPerCGA && !CTASplitNum && !CTAOrder) {
518-
return CTALayoutAttr::getDefault(parser.getContext(), rank);
500+
return CTAEncodingAttr::getDefault(parser.getContext(), rank);
519501
}
520502
parser.emitError(parser.getNameLoc(), "CTAsPerCGA, CTASplitNum, and CTAOrder "
521503
"must all be present or all be absent");
@@ -524,8 +506,8 @@ std::optional<CTALayoutAttr> getCTALayoutOrError(
524506

525507
// Print the CTALayout if it's not equal to the default.
526508
void maybePrintCTALayout(mlir::MLIRContext *context, mlir::AsmPrinter &printer,
527-
CTALayoutAttr layout, unsigned rank) {
528-
if (layout != CTALayoutAttr::getDefault(context, rank)) {
509+
CTAEncodingAttr layout, unsigned rank) {
510+
if (layout != CTAEncodingAttr::getDefault(context, rank)) {
529511
printer << ", CTAsPerCGA = [" << ArrayRef(layout.getCTAsPerCGA()) << "]"
530512
<< ", CTASplitNum = [" << ArrayRef(layout.getCTASplitNum()) << "]"
531513
<< ", CTAOrder = [" << ArrayRef(layout.getCTAOrder()) << "]";
@@ -536,7 +518,7 @@ void maybePrintCTALayout(mlir::MLIRContext *context, mlir::AsmPrinter &printer,
536518

537519
LogicalResult Subgroup2DBlockEncodingAttr::verify(
538520
function_ref<InFlightDiagnostic()> emitError,
539-
ArrayRef<unsigned> warpsPerCTA, CTALayoutAttr CTALayout,
521+
ArrayRef<unsigned> warpsPerCTA, CTAEncodingAttr CTALayout,
540522
ArrayRef<unsigned> instrShape, unsigned numBlocks, ArrayRef<unsigned> order,
541523
unsigned kWidth, unsigned threadsPerWarp) {
542524
if (instrShape.size() != 2) {
@@ -621,7 +603,7 @@ Attribute Subgroup2DBlockEncodingAttr::parse(AsmParser &parser, Type type) {
621603
}
622604
}
623605

624-
std::optional<CTALayoutAttr> CTALayout = getCTALayoutOrError(
606+
std::optional<CTAEncodingAttr> CTALayout = getCTALayoutOrError(
625607
parser, CTAsPerCGA, CTASplitNum, CTAOrder, /*rank=*/warpsPerCTA.size());
626608
if (!CTALayout.has_value())
627609
return {};
@@ -898,8 +880,10 @@ struct TritonIntelGPUInferLayoutInterface
898880
// Cowardly refuse to handle encodings with multiple CTAs. CTAsPerCGA
899881
// should be like the other fields in blocked encoding, but I'm not sure how
900882
// to handle CTASplitNum.
901-
if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) ||
902-
!all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) {
883+
if (!all_of(src.getCTALayout().getCTAsPerCGA(),
884+
[](int32_t x) { return x == 1; }) ||
885+
!all_of(src.getCTALayout().getCTASplitNum(),
886+
[](int32_t x) { return x == 1; })) {
903887
return failure();
904888
}
905889

@@ -1074,7 +1058,7 @@ struct TritonIntelGPUInferLayoutInterface
10741058
auto dstOrder = inversePermutation(dstInvOrder);
10751059

10761060
// CTALayout can be all 1's because we bailed on multi-CTA layouts above.
1077-
auto CTALayout = CTALayoutAttr::get(
1061+
auto CTALayout = CTAEncodingAttr::fromSplitParams(
10781062
src.getContext(),
10791063
/*CTAsPerCGA=*/SmallVector<unsigned>(dstShape.size(), 1),
10801064
/*CTASplitNum=*/SmallVector<unsigned>(dstShape.size(), 1),

0 commit comments

Comments (0)