Skip to content

Commit a92234a

Browse files
committed
further reduction
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 55a8540 commit a92234a

File tree

1 file changed

+13
-42
lines changed

1 file changed

+13
-42
lines changed

python/test/unit/intel/test_regressions.py

Lines changed: 13 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,11 @@ def test_kernel_from_09_tutorial(device, tmp_path: pathlib.Path):
6161
#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
6262
#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
6363
#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
64-
#blocked3 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
6564
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
6665
#shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
6766
#smem = #ttg.shared_memory
6867
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 32 : i32, ttig.min_sg_size = 8 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.target_arch = "spir64"} {
69-
tt.func public @matmul_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32} ) attributes {noinline = false} {
68+
tt.func public @matmul_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
7069
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
7170
%c63_i32 = arith.constant 63 : i32
7271
%c127_i32 = arith.constant 127 : i32
@@ -79,82 +78,54 @@ def test_kernel_from_09_tutorial(device, tmp_path: pathlib.Path):
7978
%c128_i32 = arith.constant 128 : i32
8079
%cst_2 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
8180
%cst_3 = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
82-
%cst_4 = arith.constant dense<64> : tensor<128x64xi32, #blocked2>
83-
%cst_5 = arith.constant dense<64> : tensor<64x128xi32, #blocked1>
8481
%0 = tt.get_program_id x : i32
8582
%1 = arith.addi %arg3, %c127_i32 : i32
8683
%2 = arith.divsi %1, %c128_i32 : i32
87-
%3 = arith.addi %arg4, %c127_i32 : i32
88-
%4 = arith.divsi %3, %c128_i32 : i32
89-
%5 = arith.muli %4, %c8_i32 : i32
84+
%5 = arith.muli %2, %c8_i32 : i32
9085
%6 = arith.divsi %0, %5 : i32
9186
%7 = arith.muli %6, %c8_i32 : i32
9287
%8 = arith.subi %2, %7 : i32
9388
%9 = arith.minsi %8, %c8_i32 : i32
94-
%10 = arith.remsi %0, %9 : i32
95-
%11 = arith.addi %7, %10 : i32
9689
%12 = arith.remsi %0, %5 : i32
9790
%13 = arith.divsi %12, %9 : i32
98-
%14 = arith.muli %11, %c128_i32 : i32
9991
%15 = arith.muli %13, %c128_i32 : i32
100-
%16 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
10192
%18 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
102-
%20 = tt.splat %14 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
103-
%22 = arith.addi %20, %16 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
93+
%20 = tt.splat %c128_i32 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
10494
%24 = tt.splat %15 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
10595
%26 = arith.addi %24, %18 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
10696
%28 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
107-
%29 = arith.cmpi slt, %22, %28 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
108-
%30 = arith.select %29, %22, %cst_2 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked2}>>, tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
109-
%31 = tt.splat %arg4 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
97+
%29 = arith.cmpi slt, %20, %28 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>>
98+
%31 = tt.splat %arg3 : i32 -> tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
11099
%32 = arith.cmpi slt, %26, %31 : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
111100
%33 = arith.select %32, %26, %cst_3 {tt.contiguity = dense<128> : tensor<1xi32>, tt.divisibility = dense<128> : tensor<1xi32>} : tensor<128xi1, #ttg.slice<{dim = 0, parent = #blocked1}>>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
112-
%34 = tt.expand_dims %30 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2>
113-
%35 = tt.splat %arg6 : i32 -> tensor<128x1xi32, #blocked2>
114-
%36 = arith.muli %34, %35 : tensor<128x1xi32, #blocked2>
115101
%37 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>
116102
%38 = tt.expand_dims %37 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x64xi32, #blocked2>
117-
%39 = tt.broadcast %36 : tensor<128x1xi32, #blocked2> -> tensor<128x64xi32, #blocked2>
118-
%40 = tt.broadcast %38 : tensor<1x64xi32, #blocked2> -> tensor<128x64xi32, #blocked2>
119-
%41 = arith.addi %39, %40 : tensor<128x64xi32, #blocked2>
120103
%42 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<128x64x!tt.ptr<f32>, #blocked2>
121-
%43 = tt.addptr %42, %41 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
122104
%44 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
123105
%45 = tt.expand_dims %44 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1>
124106
%46 = tt.expand_dims %33 {axis = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1>
125-
%47 = tt.splat %arg7 : i32 -> tensor<1x128xi32, #blocked1>
126-
%48 = arith.muli %46, %47 : tensor<1x128xi32, #blocked1>
127-
%49 = tt.broadcast %45 : tensor<64x1xi32, #blocked1> -> tensor<64x128xi32, #blocked1>
128-
%50 = tt.broadcast %48 : tensor<1x128xi32, #blocked1> -> tensor<64x128xi32, #blocked1>
129-
%51 = arith.addi %49, %50 : tensor<64x128xi32, #blocked1>
130-
%52 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x128x!tt.ptr<f32>, #blocked1>
131-
%53 = tt.addptr %52, %51 : tensor<64x128x!tt.ptr<f32>, #blocked1>, tensor<64x128xi32, #blocked1>
132-
%54 = arith.addi %arg5, %c63_i32 : i32
133-
%55 = arith.divsi %54, %c64_i32 : i32
134-
%56 = arith.remsi %arg5, %c64_i32 : i32
135-
%57 = arith.cmpi eq, %56, %c0_i32 : i32
136-
%58 = arith.cmpi sgt, %arg5, %c64_i32 : i32
137-
%59 = arith.andi %57, %58 : i1
107+
%50 = tt.broadcast %46 : tensor<1x128xi32, #blocked1> -> tensor<64x128xi32, #blocked1>
108+
%52 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x128x!tt.ptr<f32>, #blocked1>
109+
%53 = tt.addptr %52, %50 : tensor<64x128x!tt.ptr<f32>, #blocked1>, tensor<64x128xi32, #blocked1>
138110
139111
%80 = arith.muli %c0_i32, %c64_i32 : i32
140112
%81 = arith.subi %arg5, %80 : i32
141113
%82 = tt.splat %81 : i32 -> tensor<1x64xi32, #blocked2>
142114
%83 = arith.cmpi slt, %38, %82 : tensor<1x64xi32, #blocked2>
143115
%84 = tt.broadcast %83 : tensor<1x64xi1, #blocked2> -> tensor<128x64xi1, #blocked2>
144-
%85 = tt.load %43, %84, %cst_1 : tensor<128x64x!tt.ptr<f32>, #blocked2>
116+
%85 = tt.load %42, %84, %cst_1 : tensor<128x64x!tt.ptr<f32>, #blocked2>
145117
%86 = tt.splat %81 : i32 -> tensor<64x1xi32, #blocked1>
146118
%87 = arith.cmpi slt, %45, %86 : tensor<64x1xi32, #blocked1>
147119
%88 = tt.broadcast %87 : tensor<64x1xi1, #blocked1> -> tensor<64x128xi1, #blocked1>
148120
%89 = tt.load %53, %88, %cst_0 : tensor<64x128x!tt.ptr<f32>, #blocked1>
149121
%91 = ttg.local_alloc %85 : (tensor<128x64xf32, #blocked2>) -> !ttg.memdesc<128x64xf32, #shared, #smem>
150122
%92 = ttg.local_load %91 : !ttg.memdesc<128x64xf32, #shared, #smem> -> tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
151123
%94 = ttg.local_alloc %89 : (tensor<64x128xf32, #blocked1>) -> !ttg.memdesc<64x128xf32, #shared1, #smem>
152-
%95 = ttg.local_load %94 : !ttg.memdesc<64x128xf32, #shared1, #smem> -> tensor<64x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
153-
%96 = tt.dot %92, %95, %cst, inputPrecision = tf32 : tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked>
154-
%97 = tt.addptr %43, %cst_4 : tensor<128x64x!tt.ptr<f32>, #blocked2>, tensor<128x64xi32, #blocked2>
155-
%98 = tt.addptr %53, %cst_5 : tensor<64x128x!tt.ptr<f32>, #blocked1>, tensor<64x128xi32, #blocked1>
124+
%cst_test = arith.constant dense<1.11111116> : tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
125+
%cst_test2 = arith.constant dense<1.11111116> : tensor<64x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
126+
%96 = tt.dot %92, %cst_test2, %cst, inputPrecision = tf32 : tensor<128x64xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x128xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked>
156127
157-
%78 = ttg.convert_layout %96 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #blocked3>
128+
%78 = ttg.convert_layout %96 : tensor<128x128xf32, #blocked> -> tensor<128x128xf32, #blocked2>
158129
tt.return
159130
}
160131
}

0 commit comments

Comments
 (0)