Skip to content

Commit 016e2c1

Browse files
committed
add lit for tptr-to-llvm
1 parent 40e62ae commit 016e2c1

File tree

3 files changed

+1035
-79
lines changed

3 files changed

+1035
-79
lines changed
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
// RUN: triton-shared-opt --tptr-to-llvm %s | FileCheck %s
2+
3+
// Test input: IR as dumped immediately before the TPtrToLLVM pass (tptr-to-llvm) ran on the 'builtin.module' operation.
4+
// Location attribute aliases, first group.
// #loc2..#loc9 are the line-4 positions that also appear inline on the eight
// function arguments (%arg0..%arg7) of @bitcast_ptr_as_src.
// #loc20..#loc45 are the positions that also appear inline on the loop-header
// block arguments (^bb1, ^bb4, ^bb7, ^bb10, ^bb13, ^bb16, ^bb19, ^bb22, ^bb25,
// ^bb28) and are reused by the ops inside those loops.
#loc2 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":4:33)
5+
#loc3 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":4:55)
6+
#loc4 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":4:77)
7+
#loc5 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":4:89)
8+
#loc6 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":4:101)
9+
#loc7 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":4:113)
10+
#loc8 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":4:125)
11+
#loc9 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":4:137)
12+
#loc20 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":15:10)
13+
#loc23 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":18:10)
14+
#loc26 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":25:11)
15+
#loc27 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":26:11)
16+
#loc30 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":32:11)
17+
#loc34 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":39:11)
18+
#loc38 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":46:11)
19+
#loc39 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":47:11)
20+
#loc42 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":53:11)
21+
#loc45 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":59:5)
22+
// Test input module: a single function mixing tptr, memref, arith and cf ops.
// All loops below iterate an index from %c0 up to (but excluding) %c16 in
// steps of %c1, using a header block with an `slt` compare and a body block.
module {
23+
// @bitcast_ptr_as_src takes two unranked i32 memrefs and six i32 scalars.
// Only %arg0 and %arg1 are referenced in the body; %arg2..%arg7 are unused.
func.func @bitcast_ptr_as_src(%arg0: memref<*xi32> loc("/tmp/tmpq_mdk7uo/ttshared.mlir":4:33), %arg1: memref<*xi32> loc("/tmp/tmpq_mdk7uo/ttshared.mlir":4:55), %arg2: i32 loc("/tmp/tmpq_mdk7uo/ttshared.mlir":4:77), %arg3: i32 loc("/tmp/tmpq_mdk7uo/ttshared.mlir":4:89), %arg4: i32 loc("/tmp/tmpq_mdk7uo/ttshared.mlir":4:101), %arg5: i32 loc("/tmp/tmpq_mdk7uo/ttshared.mlir":4:113), %arg6: i32 loc("/tmp/tmpq_mdk7uo/ttshared.mlir":4:125), %arg7: i32 loc("/tmp/tmpq_mdk7uo/ttshared.mlir":4:137)) {
24+
// Loop step (%c1), loop bound (%c16) and loop start / element index (%c0).
%c1 = arith.constant 1 : index loc(#loc10)
25+
%c16 = arith.constant 16 : index loc(#loc10)
26+
%c0 = arith.constant 0 : index loc(#loc10)
27+
// %0 and %1: tptr.type_offset of i64 and of i32, each produced as an i32.
// They are used below as multipliers when computing pointer offsets.
%0 = tptr.type_offset i64 : i32 loc(#loc11)
28+
%1 = tptr.type_offset i32 : i32 loc(#loc12)
29+
%c2_i32 = arith.constant 2 : i32 loc(#loc13)
30+
%c1_i32 = arith.constant 1 : i32 loc(#loc14)
31+
// Cast each unranked argument to memref<1xi32> and turn it into a tptr
// pointer: %2 comes from %arg1, %3 comes from %arg0.
%cast = memref.cast %arg1 : memref<*xi32> to memref<1xi32> loc(#loc15)
32+
%2 = tptr.from_memref %cast : memref<1xi32> to <#tptr.default_memory_space> loc(#loc16)
33+
%cast_0 = memref.cast %arg0 : memref<*xi32> to memref<1xi32> loc(#loc17)
34+
%3 = tptr.from_memref %cast_0 : memref<1xi32> to <#tptr.default_memory_space> loc(#loc18)
35+
// Two 16 x i32 scratch buffers (%alloc, %alloc_1).
%alloc = memref.alloc() {alignment = 64 : i64} : memref<16xi32> loc(#loc19)
36+
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<16xi32> loc(#loc20)
37+
// Loop ^bb1/^bb2: store the constant 2 into every element of %alloc_1.
cf.br ^bb1(%c0 : index) loc(#loc20)
38+
^bb1(%4: index loc("/tmp/tmpq_mdk7uo/ttshared.mlir":15:10)): // 2 preds: ^bb0, ^bb2
39+
%5 = arith.cmpi slt, %4, %c16 : index loc(#loc20)
40+
cf.cond_br %5, ^bb2, ^bb3 loc(#loc20)
41+
^bb2: // pred: ^bb1
42+
memref.store %c2_i32, %alloc_1[%4] : memref<16xi32> loc(#loc20)
43+
%6 = arith.addi %4, %c1 : index loc(#loc20)
44+
cf.br ^bb1(%6 : index) loc(#loc20)
45+
// %7 = 1 * type_offset(i32); %8 = pointer %3 advanced by %7.
// %7 is reused later (in ^bb18) to advance %2 by the same amount.
^bb3: // pred: ^bb1
46+
%7 = arith.muli %c1_i32, %1 : i32 loc(#loc21)
47+
%8 = tptr.ptradd %3 %7 : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> loc(#loc22)
48+
// Loop ^bb4/^bb5: %alloc[i] = i (loop index cast to i32).
cf.br ^bb4(%c0 : index) loc(#loc23)
49+
^bb4(%9: index loc("/tmp/tmpq_mdk7uo/ttshared.mlir":18:10)): // 2 preds: ^bb3, ^bb5
50+
%10 = arith.cmpi slt, %9, %c16 : index loc(#loc23)
51+
cf.cond_br %10, ^bb5, ^bb6 loc(#loc23)
52+
^bb5: // pred: ^bb4
53+
%11 = arith.index_cast %9 : index to i32 loc(#loc24)
54+
memref.store %11, %alloc[%9] : memref<16xi32> loc(#loc23)
55+
%12 = arith.addi %9, %c1 : index loc(#loc23)
56+
cf.br ^bb4(%12 : index) loc(#loc23)
57+
// %alloc_2 is a 16-element buffer whose elements are tptr pointers.
// Loop ^bb7/^bb8: store pointer %8 into every slot of %alloc_2.
^bb6: // pred: ^bb4
58+
%alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<16x!ptr.ptr<#tptr.default_memory_space>> loc(#loc25)
59+
cf.br ^bb7(%c0 : index) loc(#loc26)
60+
^bb7(%13: index loc("/tmp/tmpq_mdk7uo/ttshared.mlir":25:11)): // 2 preds: ^bb6, ^bb8
61+
%14 = arith.cmpi slt, %13, %c16 : index loc(#loc26)
62+
cf.cond_br %14, ^bb8, ^bb9 loc(#loc26)
63+
^bb8: // pred: ^bb7
64+
memref.store %8, %alloc_2[%13] : memref<16x!ptr.ptr<#tptr.default_memory_space>> loc(#loc26)
65+
%15 = arith.addi %13, %c1 : index loc(#loc26)
66+
cf.br ^bb7(%15 : index) loc(#loc26)
67+
// Loop ^bb10/^bb11: %alloc_2[i] = ptradd(%alloc_2[i], %alloc[i] * %0),
// i.e. the i32 offset is scaled by the i64 type offset, in place.
^bb9: // pred: ^bb7
68+
cf.br ^bb10(%c0 : index) loc(#loc27)
69+
^bb10(%16: index loc("/tmp/tmpq_mdk7uo/ttshared.mlir":26:11)): // 2 preds: ^bb9, ^bb11
70+
%17 = arith.cmpi slt, %16, %c16 : index loc(#loc27)
71+
cf.cond_br %17, ^bb11, ^bb12 loc(#loc27)
72+
^bb11: // pred: ^bb10
73+
%18 = memref.load %alloc_2[%16] : memref<16x!ptr.ptr<#tptr.default_memory_space>> loc(#loc27)
74+
%19 = memref.load %alloc[%16] : memref<16xi32> loc(#loc27)
75+
%20 = arith.muli %19, %0 : i32 loc(#loc28)
76+
%21 = tptr.ptradd %18 %20 : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> loc(#loc29)
77+
memref.store %21, %alloc_2[%16] : memref<16x!ptr.ptr<#tptr.default_memory_space>> loc(#loc27)
78+
%22 = arith.addi %16, %c1 : index loc(#loc27)
79+
cf.br ^bb10(%22 : index) loc(#loc27)
80+
// Loop ^bb13/^bb14: same in-place update, with offsets read from %alloc_1.
^bb12: // pred: ^bb10
81+
cf.br ^bb13(%c0 : index) loc(#loc30)
82+
^bb13(%23: index loc("/tmp/tmpq_mdk7uo/ttshared.mlir":32:11)): // 2 preds: ^bb12, ^bb14
83+
%24 = arith.cmpi slt, %23, %c16 : index loc(#loc30)
84+
cf.cond_br %24, ^bb14, ^bb15 loc(#loc30)
85+
^bb14: // pred: ^bb13
86+
%25 = memref.load %alloc_2[%23] : memref<16x!ptr.ptr<#tptr.default_memory_space>> loc(#loc30)
87+
%26 = memref.load %alloc_1[%23] : memref<16xi32> loc(#loc30)
88+
%27 = arith.muli %26, %0 : i32 loc(#loc31)
89+
%28 = tptr.ptradd %25 %27 : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> loc(#loc32)
90+
memref.store %28, %alloc_2[%23] : memref<16x!ptr.ptr<#tptr.default_memory_space>> loc(#loc30)
91+
%29 = arith.addi %23, %c1 : index loc(#loc30)
92+
cf.br ^bb13(%29 : index) loc(#loc30)
93+
// %alloc_3 is a 16 x i64 buffer.
// Loop ^bb16/^bb17: view each pointer in %alloc_2 as memref<1xi64> via
// tptr.to_memref, load element 0, and store the i64 into %alloc_3[i].
^bb15: // pred: ^bb13
94+
%alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<16xi64> loc(#loc33)
95+
cf.br ^bb16(%c0 : index) loc(#loc34)
96+
^bb16(%30: index loc("/tmp/tmpq_mdk7uo/ttshared.mlir":39:11)): // 2 preds: ^bb15, ^bb17
97+
%31 = arith.cmpi slt, %30, %c16 : index loc(#loc34)
98+
cf.cond_br %31, ^bb17, ^bb18 loc(#loc34)
99+
^bb17: // pred: ^bb16
100+
%32 = memref.load %alloc_2[%30] : memref<16x!ptr.ptr<#tptr.default_memory_space>> loc(#loc34)
101+
%33 = tptr.to_memref %32 : <#tptr.default_memory_space> to memref<1xi64> loc(#loc35)
102+
%34 = memref.load %33[%c0] : memref<1xi64> loc(#loc36)
103+
memref.store %34, %alloc_3[%30] : memref<16xi64> loc(#loc34)
104+
%35 = arith.addi %30, %c1 : index loc(#loc34)
105+
cf.br ^bb16(%35 : index) loc(#loc34)
106+
// %36 = pointer %2 (derived from %arg1) advanced by %7.
// Loop ^bb19/^bb20: overwrite every slot of %alloc_2 with %36.
^bb18: // pred: ^bb16
107+
%36 = tptr.ptradd %2 %7 : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> loc(#loc37)
108+
cf.br ^bb19(%c0 : index) loc(#loc38)
109+
^bb19(%37: index loc("/tmp/tmpq_mdk7uo/ttshared.mlir":46:11)): // 2 preds: ^bb18, ^bb20
110+
%38 = arith.cmpi slt, %37, %c16 : index loc(#loc38)
111+
cf.cond_br %38, ^bb20, ^bb21 loc(#loc38)
112+
^bb20: // pred: ^bb19
113+
memref.store %36, %alloc_2[%37] : memref<16x!ptr.ptr<#tptr.default_memory_space>> loc(#loc38)
114+
%39 = arith.addi %37, %c1 : index loc(#loc38)
115+
cf.br ^bb19(%39 : index) loc(#loc38)
116+
// Loop ^bb22/^bb23: %alloc_2[i] = ptradd(%alloc_2[i], %alloc[i] * %0).
^bb21: // pred: ^bb19
117+
cf.br ^bb22(%c0 : index) loc(#loc39)
118+
^bb22(%40: index loc("/tmp/tmpq_mdk7uo/ttshared.mlir":47:11)): // 2 preds: ^bb21, ^bb23
119+
%41 = arith.cmpi slt, %40, %c16 : index loc(#loc39)
120+
cf.cond_br %41, ^bb23, ^bb24 loc(#loc39)
121+
^bb23: // pred: ^bb22
122+
%42 = memref.load %alloc_2[%40] : memref<16x!ptr.ptr<#tptr.default_memory_space>> loc(#loc39)
123+
%43 = memref.load %alloc[%40] : memref<16xi32> loc(#loc39)
124+
%44 = arith.muli %43, %0 : i32 loc(#loc40)
125+
%45 = tptr.ptradd %42 %44 : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> loc(#loc41)
126+
memref.store %45, %alloc_2[%40] : memref<16x!ptr.ptr<#tptr.default_memory_space>> loc(#loc39)
127+
%46 = arith.addi %40, %c1 : index loc(#loc39)
128+
cf.br ^bb22(%46 : index) loc(#loc39)
129+
// Loop ^bb25/^bb26: %alloc_2[i] = ptradd(%alloc_2[i], %alloc_1[i] * %0).
^bb24: // pred: ^bb22
130+
cf.br ^bb25(%c0 : index) loc(#loc42)
131+
^bb25(%47: index loc("/tmp/tmpq_mdk7uo/ttshared.mlir":53:11)): // 2 preds: ^bb24, ^bb26
132+
%48 = arith.cmpi slt, %47, %c16 : index loc(#loc42)
133+
cf.cond_br %48, ^bb26, ^bb27 loc(#loc42)
134+
^bb26: // pred: ^bb25
135+
%49 = memref.load %alloc_2[%47] : memref<16x!ptr.ptr<#tptr.default_memory_space>> loc(#loc42)
136+
%50 = memref.load %alloc_1[%47] : memref<16xi32> loc(#loc42)
137+
%51 = arith.muli %50, %0 : i32 loc(#loc43)
138+
%52 = tptr.ptradd %49 %51 : <#tptr.default_memory_space>, i32 to <#tptr.default_memory_space> loc(#loc44)
139+
memref.store %52, %alloc_2[%47] : memref<16x!ptr.ptr<#tptr.default_memory_space>> loc(#loc42)
140+
%53 = arith.addi %47, %c1 : index loc(#loc42)
141+
cf.br ^bb25(%53 : index) loc(#loc42)
142+
// Loop ^bb28/^bb29: view each pointer in %alloc_2 as memref<1xi64> and
// store the i64 previously saved in %alloc_3[i] into its element 0.
^bb27: // pred: ^bb25
143+
cf.br ^bb28(%c0 : index) loc(#loc45)
144+
^bb28(%54: index loc("/tmp/tmpq_mdk7uo/ttshared.mlir":59:5)): // 2 preds: ^bb27, ^bb29
145+
%55 = arith.cmpi slt, %54, %c16 : index loc(#loc45)
146+
cf.cond_br %55, ^bb29, ^bb30 loc(#loc45)
147+
^bb29: // pred: ^bb28
148+
%56 = memref.load %alloc_2[%54] : memref<16x!ptr.ptr<#tptr.default_memory_space>> loc(#loc45)
149+
%57 = memref.load %alloc_3[%54] : memref<16xi64> loc(#loc45)
150+
%58 = tptr.to_memref %56 : <#tptr.default_memory_space> to memref<1xi64> loc(#loc46)
151+
memref.store %57, %58[%c0] : memref<1xi64> loc(#loc47)
152+
%59 = arith.addi %54, %c1 : index loc(#loc45)
153+
cf.br ^bb28(%59 : index) loc(#loc45)
154+
// Exit block: the function returns no values.
^bb30: // pred: ^bb28
155+
return loc(#loc48)
156+
} loc(#loc1)
157+
} loc(#loc)
158+
// Location attribute aliases, second group.
// #loc is attached to the module and #loc1 to the function.
// #loc10 is an unknown location; it is used by the index constants %c1, %c16
// and %c0.
// The remaining aliases are the positions of the individual non-loop-header
// ops in the function body.
#loc = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":3:1)
159+
#loc1 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":4:3)
160+
#loc10 = loc(unknown)
161+
#loc11 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":6:10)
162+
#loc12 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":7:10)
163+
#loc13 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":8:15)
164+
#loc14 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":9:15)
165+
#loc15 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":10:13)
166+
#loc16 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":11:10)
167+
#loc17 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":12:15)
168+
#loc18 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":13:10)
169+
#loc19 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":14:10)
170+
#loc21 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":16:10)
171+
#loc22 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":17:10)
172+
#loc24 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":21:13)
173+
#loc25 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":24:10)
174+
#loc28 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":28:13)
175+
#loc29 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":29:13)
176+
#loc31 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":34:13)
177+
#loc32 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":35:13)
178+
#loc33 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":38:11)
179+
#loc35 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":41:13)
180+
#loc36 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":42:13)
181+
#loc37 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":45:11)
182+
#loc40 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":49:13)
183+
#loc41 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":50:13)
184+
#loc43 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":55:13)
185+
#loc44 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":56:13)
186+
#loc46 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":61:13)
187+
#loc47 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":62:7)
188+
#loc48 = loc("/tmp/tmpq_mdk7uo/ttshared.mlir":65:5)
189+
190+
191+
192+
// CHECK-NOT: tptr.
193+
// CHECK-NOT: ptr.ptr

0 commit comments

Comments
 (0)