
Commit 35f83ef

[Intel] Use 'CTAEncodingAttr' after '49b7472'
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent b82b6a8 commit 35f83ef

File tree

14 files changed: +92 −118 lines changed

python/test/unit/intel/test_block_io.py

Lines changed: 6 additions & 10 deletions
@@ -53,18 +53,14 @@ def __str__(self):
 
 
 class BlockedLayout:
 
-    def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga=[1, 1],
-                 cta_split_num=[1, 1], cta_order=[0, 1]):
+    def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
         self.sz_per_thread = size_per_thread
         self.threads_per_warp = threads_per_warp
         self.warps_per_cta = warps_per_cta
         self.order = order
-        self.ctas_per_cga = ctas_per_cga
-        self.cta_split_num = cta_split_num
-        self.cta_order = cta_order
 
     def __str__(self):
-        return f"#ttg.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
+        return f"#ttg.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
 
 
 def warps_per_cta(layout):
@@ -75,7 +71,7 @@ def warps_per_cta(layout):
 
 
 layouts = [
-    BlockedLayout([1, 1], [2, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
+    BlockedLayout([1, 1], [2, 16], [4, 1], [1, 0]),
     # DPAS layout
     DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=4, threads_per_warp=16,
                warps_per_cta=[1, 4], rep_cluster=[1, 2]),
@@ -110,8 +106,7 @@ def warps_per_cta(layout):
                 parent=DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=1, threads_per_warp=32,
                                   warps_per_cta=[2, 2], rep_cluster=[1, 1]), op_idx=1, k_width=1),
     # Slice layout
-    SliceLayout(dim=1, parent=BlockedLayout([1, 4, 1], [2, 1, 16], [2, 1, 2], [2, 1, 0], [1, 1, 1], [1, 1, 1],
-                                            [0, 1, 2])),
+    SliceLayout(dim=1, parent=BlockedLayout([1, 4, 1], [2, 1, 16], [2, 1, 2], [2, 1, 0])),
 ]
 
 
@@ -136,7 +131,8 @@ def test_block_io(M, N, dtype_str, layout, load_block_ptr, store_block_ptr, tran
     block_io = "\"column_major\"" if transpose else "\"row_major\""
 
     strides = "[%c1_i64, %M_i64]" if transpose else "[%N_i64, %c1_i64]"
-
+    #breakpoint()
+    print(layout)
     if load_block_ptr:
         load_ops = f"""
             %src_ptr = tt.make_tensor_ptr %src, [%M_i64, %N_i64], {strides}, [%c0_i32, %c0_i32] {{order = array<i32: 1, 0>}} : <tensor<{M}x{N}x{ty}, #layout>>
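For reference, a minimal runnable sketch of the trimmed printer. The class body is copied verbatim from the hunk above; that the omitted CTAsPerCGA/CTASplitNum/CTAOrder fields fall back to a default single-CTA layout on the parser side is an assumption based on the upstream change this commit tracks.

# Sketch: the trimmed BlockedLayout from the hunk above, printing the
# #ttg.blocked attribute without the CTA fields.
class BlockedLayout:

    def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
        self.sz_per_thread = size_per_thread
        self.threads_per_warp = threads_per_warp
        self.warps_per_cta = warps_per_cta
        self.order = order

    def __str__(self):
        return f"#ttg.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"

print(BlockedLayout([1, 1], [2, 16], [4, 1], [1, 0]))
# -> #ttg.blocked<{sizePerThread=[1, 1], threadsPerWarp=[2, 16], warpsPerCTA=[4, 1], order=[1, 0]}>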

python/test/unit/intel/test_core.py

Lines changed: 44 additions & 51 deletions
@@ -62,33 +62,26 @@ def __str__(self):
 
 
 class BlockedLayout:
 
-    def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga=[1, 1],
-                 cta_split_num=[1, 1], cta_order=[0, 1]):
+    def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
         self.sz_per_thread = size_per_thread
         self.threads_per_warp = threads_per_warp
         self.warps_per_cta = warps_per_cta
         self.order = order
-        self.ctas_per_cga = ctas_per_cga
-        self.cta_split_num = cta_split_num
-        self.cta_order = cta_order
 
     def __str__(self):
-        return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
+        return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
 
 
 class SwizzledSharedLayout:
 
-    def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order):
+    def __init__(self, vec, per_phase, max_phase, order):
         self.vec = vec
         self.per_phase = per_phase
         self.max_phase = max_phase
         self.order = order
-        self.ctas_per_cga = ctas_per_cga
-        self.cta_split_num = cta_split_num
-        self.cta_order = cta_order
 
     def __str__(self):
-        return f"#{GPU_DIALECT}.swizzled_shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
+        return f"#{GPU_DIALECT}.swizzled_shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}}}>"
 
 
 class PaddedSharedLayout:
@@ -172,17 +165,17 @@ def get_reduce_input(dtype_str, shape):
 
 
 scan_layouts = [
-    BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 2], [1, THREADS_PER_WARP // 1], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]),
+    BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1]),
+    BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1]),
+    BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1]),
+    BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1]),
+    BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1]),
+    BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0]),
+    BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0]),
+    BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0]),
+    BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0]),
+    BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0]),
+    BlockedLayout([1, 2], [1, THREADS_PER_WARP // 1], [1, 4], [1, 0]),
 ]
 
 
@@ -254,8 +247,8 @@ def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_pa
 
 
 layouts = [
-    BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
+    BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0]),
+    BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1]),
     DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,
                warps_per_cta=[4, 1], rep_cluster=[1, 1]),
     DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=32,
@@ -305,8 +298,8 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov
     store_range = "%7" if axis == 0 else "%1"
     warps = warps_per_cta(src_layout, [M, N])
     num_warps = int(np.prod(warps))
-    blocked = BlockedLayout([1, 1], [32, THREADS_PER_WARP // 32], [4, num_warps // 4], [0, 1], [1, 1], [1, 1], [0, 1])
-    one_d_layout = BlockedLayout([1], [THREADS_PER_WARP], [num_warps], [0], [1], [1], [0])
+    blocked = BlockedLayout([1, 1], [32, THREADS_PER_WARP // 32], [4, num_warps // 4], [0, 1])
+    one_d_layout = BlockedLayout([1], [THREADS_PER_WARP], [num_warps], [0])
 
     expanded_shape = f"1x{N}" if axis == 0 else f"{M}x1"
     other_axis = 1 - axis
@@ -397,8 +390,8 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov
 
 
 layouts = [
-    BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
+    BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0]),
+    BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0]),
     DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,
                warps_per_cta=[4, 1], rep_cluster=[1, 1]),
 ]
@@ -443,8 +436,8 @@ def test_store_op(M, src_layout, device, tmp_path: pathlib.Path):
 
 
 layouts = [
-    BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
+    BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0]),
+    BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0]),
     DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,
                warps_per_cta=[4, 1], rep_cluster=[1, 1])
 ]
@@ -532,10 +525,10 @@ def test_convert1d_bool(M, src_layout, dst_layout, src_dim, dst_dim, device, tmp
 
 
 layouts = [
-    BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 4], [THREADS_PER_WARP // 32, 32], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1])
+    BlockedLayout([1, 4], [1, THREADS_PER_WARP], [4, 1], [1, 0]),
+    BlockedLayout([1, 4], [1, THREADS_PER_WARP], [2, 2], [1, 0]),
+    BlockedLayout([1, 4], [THREADS_PER_WARP // 32, 32], [1, 4], [1, 0]),
+    BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1])
 ]
 
 
@@ -611,8 +604,8 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path: pathli
 # TODO: backend should be tested separately
 
 layouts = [
-    BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
+    BlockedLayout([1, 1], [THREADS_PER_WARP, 1], [2, 2], [0, 1]),
+    BlockedLayout([1, 16], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0]),
     DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,
                warps_per_cta=[4, 1], rep_cluster=[1, 1]),
     DpasLayout(repeatCount=2, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,
@@ -621,10 +614,10 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path: pathli
 
 intermediate_layouts = [
     None,
-    SwizzledSharedLayout(1, 1, 1, [0, 1], [1, 1], [1, 1], [0, 1]),
-    SwizzledSharedLayout(1, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]),
-    SwizzledSharedLayout(4, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]),
-    SwizzledSharedLayout(2, 2, 4, [1, 0], [1, 1], [1, 1], [0, 1]),
+    SwizzledSharedLayout(1, 1, 1, [0, 1]),
+    SwizzledSharedLayout(1, 1, 1, [1, 0]),
+    SwizzledSharedLayout(4, 2, 4, [1, 0]),
+    SwizzledSharedLayout(2, 2, 4, [1, 0]),
 ]
 
 
@@ -736,15 +729,15 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t
 
 
 layouts_3d = [
-    BlockedLayout([4, 4, 1], [1, 8, THREADS_PER_WARP // 8], [2, 2, 1], [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
-    BlockedLayout([1, 1, 4], [8, THREADS_PER_WARP // 8, 1], [2, 1, 2], [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
+    BlockedLayout([4, 4, 1], [1, 8, THREADS_PER_WARP // 8], [2, 2, 1], [2, 1, 0]),
+    BlockedLayout([1, 1, 4], [8, THREADS_PER_WARP // 8, 1], [2, 1, 2], [1, 2, 0]),
 ]
 
 shared_layouts_3d = [
-    SwizzledSharedLayout(1, 1, 1, [2, 1, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
-    SwizzledSharedLayout(4, 2, 4, [1, 2, 0], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
-    SwizzledSharedLayout(8, 2, 4, [0, 2, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
-    SwizzledSharedLayout(4, 2, 1, [2, 0, 1], [1, 1, 1], [1, 1, 1], [0, 1, 2]),
+    SwizzledSharedLayout(1, 1, 1, [2, 1, 0]),
+    SwizzledSharedLayout(4, 2, 4, [1, 2, 0]),
+    SwizzledSharedLayout(8, 2, 4, [0, 2, 1]),
+    SwizzledSharedLayout(4, 2, 1, [2, 0, 1]),
 ]
 
 
@@ -841,9 +834,9 @@ def test_local_load_store(M, N, K, dist_layout, shared_layout, device, tmp_path:
 ]
 
 shared_layouts = [
-    SwizzledSharedLayout(4, 2, 4, [0, 1], [1, 1], [1, 1], [0, 1]),
-    SwizzledSharedLayout(8, 1, 8, [1, 0], [1, 1], [1, 1], [0, 1]),
-    SwizzledSharedLayout(16, 1, 16, [1, 0], [1, 1], [1, 1], [0, 1]),
+    SwizzledSharedLayout(4, 2, 4, [0, 1]),
+    SwizzledSharedLayout(8, 1, 8, [1, 0]),
+    SwizzledSharedLayout(16, 1, 16, [1, 0]),
 ]
 
 
@@ -855,7 +848,7 @@ def test_split_subview(M, N, M_tile_size, N_tile_size, device, tmp_path: pathlib
     num_repeats_N = triton.cdiv(N, N_tile_size)
 
     ir = f"""
-    #blocked = #ttg.blocked<{{sizePerThread=[1, 8], threadsPerWarp=[{num_rows_per_warp}, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}}>
+    #blocked = #ttg.blocked<{{sizePerThread=[1, 8], threadsPerWarp=[{num_rows_per_warp}, 4], warpsPerCTA=[4, 1], order=[1, 0]}}>
     #shared = #ttg.swizzled_shared<{{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}}>
     #smem = #ttg.shared_memory
 
@@ -989,7 +982,7 @@ def test_local_load_store_dot(M, N, dtype, dist_layout, shared_layout, device, t
 ]
 
 shared_layouts = [
-    SwizzledSharedLayout(8, 1, 1, [1, 0], [1, 1], [1, 1], [0, 1]),
+    SwizzledSharedLayout(8, 1, 1, [1, 0]),
 ]
 
 
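The same trimming applies to the shared-memory helper. A small sketch of the new printer; GPU_DIALECT = "ttg" is an assumption here (the tests derive it from the active backend):

# Sketch of the trimmed SwizzledSharedLayout printer from the hunks above.
GPU_DIALECT = "ttg"  # assumed value for this standalone example

class SwizzledSharedLayout:

    def __init__(self, vec, per_phase, max_phase, order):
        self.vec = vec
        self.per_phase = per_phase
        self.max_phase = max_phase
        self.order = order

    def __str__(self):
        return f"#{GPU_DIALECT}.swizzled_shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}}}>"

print(SwizzledSharedLayout(8, 1, 8, [1, 0]))
# -> #ttg.swizzled_shared<{vec=8, perPhase=1, maxPhase=8, order=[1, 0]}>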

python/triton/experimental/gluon/language/intel/_layouts.py

Lines changed: 4 additions & 0 deletions
@@ -84,3 +84,7 @@ def __hash__(self):
             self.threads_per_warp,
             tuple(self.cta_order),
         ))
+
+    @property
+    def rank(self):
+        return len(self.warps_per_cta)
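The new rank property reports the dimensionality of the layout from the length of its warps_per_cta shape. A hypothetical standalone sketch (DemoLayout and its constructor are invented stand-ins; the actual gluon layout class name is outside the hunk):

# Hypothetical stand-in for the gluon layout class touched above; only the
# `rank` property mirrors the diff, the rest is invented for illustration.
class DemoLayout:

    def __init__(self, warps_per_cta):
        self.warps_per_cta = warps_per_cta

    @property
    def rank(self):
        return len(self.warps_per_cta)

print(DemoLayout([2, 2]).rank)  # -> 2 for a 2-D layout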

test/Conversion/intel/tritongpu_to_gen.mlir

Lines changed: 1 addition & 1 deletion
@@ -1389,7 +1389,7 @@ tt.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {
 
 // -----
 
-#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [4], CTASplitNum = [1], CTAOrder = [0]}>
+#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CGALayout = [[0], [0]]}>
 module attributes {"ttg.num-ctas" = 4 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: test_get_program_id
 tt.func @test_get_program_id(%a: tensor<32x!tt.ptr<i32>, #blocked0>) {

test/TritonIntelGPU/tritongpu_reduce_op_lowering.mlir

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 // COM: Tests reduction when threads_per_warp < num_warps.
 
-#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [64], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [64], order = [0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} {
 // CHECK-LABEL: reduce_problem_size_64_threads_per_warp_32
 tt.func @reduce_problem_size_64_threads_per_warp_32(%f : tensor<2048xi32, #blocked>) {

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 1 addition & 1 deletion
@@ -310,7 +310,7 @@ def Subgroup2DBlockEncodingAttr : DistributedEncoding<"Subgroup2DBlockEncoding",
   let parameters = (
     ins
     ArrayRefParameter<"unsigned">:$warpsPerCTA,
-    "CTALayoutAttr":$CTALayout,
+    "CTAEncodingAttr":$CTALayout,
     ArrayRefParameter<"unsigned">:$instrShape,
     "unsigned":$numBlocks,
     ArrayRefParameter<"unsigned">:$order,