Skip to content

Commit 84f204b

Browse files
[RemoveLayoutConversion] Increase convert layout cost (#5477)
The cost of smem load/store and synchronization is higher on Intel GPUs compared to NV. This PR simply increases it by a factor of 2. Issue #5476 is created to further refine the remove layout conversion cost model. Benchmark CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/19356260840 (good) Fixes #5124 --------- Signed-off-by: Whitney Tsang <[email protected]>
1 parent 8a774bd commit 84f204b

File tree

2 files changed

+251
-1
lines changed
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
// RUN: triton-opt %s -tritonintelgpu-remove-layout-conversions 2>&1 | FileCheck %s

// The pass should rematerialize the #blocked1 loads/stores in #blocked so that
// no ttg.convert_layout (smem round-trip) survives in the output.
// CHECK-NOT: ttg.convert_layout
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "xpu", "ttg.threads-per-warp" = 32 : i32, ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, ttig.support_dpas, ttig.support_sg_2d_block, ttig.target_arch = "spir64"} {
  tt.func public @triton_poi_fused_max_pool2d_with_indices_max_pool2d_with_indices_backward_139(%in_ptr0: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %in_ptr1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %out_ptr0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %xnumel: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %cst = arith.constant dense<0.000000e+00> : tensor<1024xf32, #blocked>
    %cst_0 = arith.constant dense<294> : tensor<1024xi32, #blocked>
    %cst_1 = arith.constant dense<144> : tensor<1024xi32, #blocked>
    %cst_2 = arith.constant dense<3> : tensor<1024xi32, #blocked>
    %cst_3 = arith.constant dense<9> : tensor<1024xi32, #blocked>
    %cst_4 = arith.constant dense<341056> : tensor<1024xi32, #blocked>
    %cst_5 = arith.constant dense<4672> : tensor<1024xi32, #blocked>
    %cst_6 = arith.constant dense<73> : tensor<1024xi32, #blocked>
    %cst_7 = arith.constant dense<1> : tensor<1024xi32, #blocked>
    %cst_8 = arith.constant dense<2> : tensor<1024xi32, #blocked>
    %cst_9 = arith.constant dense<0> : tensor<1024xi32, #blocked>
    %cst_10 = arith.constant dense<-1> : tensor<1024xi32, #blocked>
    %cst_11 = arith.constant dense<21609> : tensor<1024xi32, #blocked>
    %cst_12 = arith.constant dense<1382976> : tensor<1024xi32, #blocked>
    %cst_13 = arith.constant dense<9408> : tensor<1024xi32, #blocked>
    %cst_14 = arith.constant dense<147> : tensor<1024xi32, #blocked>
    %cst_15 = arith.constant dense<64> : tensor<1024xi32, #blocked>
    %c1024_i32 = arith.constant 1024 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c1024_i32 : i32
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
    %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
    %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
    %5 = arith.remsi %4, %cst_15 : tensor<1024xi32, #blocked>
    %6 = arith.divsi %4, %cst_15 : tensor<1024xi32, #blocked>
    %7 = arith.remsi %6, %cst_14 : tensor<1024xi32, #blocked>
    %8 = arith.divsi %4, %cst_13 : tensor<1024xi32, #blocked>
    %9 = arith.remsi %8, %cst_14 : tensor<1024xi32, #blocked>
    %10 = arith.divsi %4, %cst_12 : tensor<1024xi32, #blocked>
    %11 = arith.remsi %6, %cst_11 : tensor<1024xi32, #blocked>
    %12 = arith.addi %7, %cst_10 : tensor<1024xi32, #blocked>
    %13 = arith.divsi %12, %cst_8 : tensor<1024xi32, #blocked>
    %14 = arith.remsi %12, %cst_8 : tensor<1024xi32, #blocked>
    %15 = arith.cmpi ne, %14, %cst_9 : tensor<1024xi32, #blocked>
    %16 = arith.subi %13, %cst_7 : tensor<1024xi32, #blocked>
    %17 = arith.select %15, %16, %13 : tensor<1024xi1, #blocked>, tensor<1024xi32, #blocked>
    %18 = arith.cmpi slt, %12, %cst_9 : tensor<1024xi32, #blocked>
    %19 = arith.select %18, %17, %13 : tensor<1024xi1, #blocked>, tensor<1024xi32, #blocked>
    %20 = arith.cmpi sgt, %19, %cst_9 : tensor<1024xi32, #blocked>
    %21 = arith.extui %20 : tensor<1024xi1, #blocked> to tensor<1024xi32, #blocked>
    %22 = arith.muli %19, %21 : tensor<1024xi32, #blocked>
    %23 = arith.divsi %7, %cst_8 : tensor<1024xi32, #blocked>
    %24 = arith.addi %23, %cst_7 : tensor<1024xi32, #blocked>
    %25 = arith.cmpi sge, %24, %cst_6 : tensor<1024xi32, #blocked>
    %26 = arith.extui %25 : tensor<1024xi1, #blocked> to tensor<1024xi32, #blocked>
    %27 = arith.muli %26, %cst_6 : tensor<1024xi32, #blocked>
    %28 = arith.cmpi slt, %24, %cst_6 : tensor<1024xi32, #blocked>
    %29 = arith.extui %28 : tensor<1024xi1, #blocked> to tensor<1024xi32, #blocked>
    %30 = arith.muli %24, %29 : tensor<1024xi32, #blocked>
    %31 = arith.addi %27, %30 : tensor<1024xi32, #blocked>
    %32 = arith.addi %31, %cst_10 : tensor<1024xi32, #blocked>
    %33 = arith.cmpi sle, %22, %32 : tensor<1024xi32, #blocked>
    %34 = arith.extui %33 : tensor<1024xi1, #blocked> to tensor<1024xi32, #blocked>
    %35 = arith.muli %22, %34 : tensor<1024xi32, #blocked>
    %36 = arith.cmpi slt, %32, %22 : tensor<1024xi32, #blocked>
    %37 = arith.extui %36 : tensor<1024xi1, #blocked> to tensor<1024xi32, #blocked>
    %38 = arith.muli %32, %37 : tensor<1024xi32, #blocked>
    %39 = arith.addi %35, %38 : tensor<1024xi32, #blocked>
    %40 = arith.muli %39, %cst_15 : tensor<1024xi32, #blocked>
    %41 = arith.addi %5, %40 : tensor<1024xi32, #blocked>
    %42 = arith.addi %9, %cst_10 : tensor<1024xi32, #blocked>
    %43 = arith.divsi %42, %cst_8 : tensor<1024xi32, #blocked>
    %44 = arith.remsi %42, %cst_8 : tensor<1024xi32, #blocked>
    %45 = arith.cmpi ne, %44, %cst_9 : tensor<1024xi32, #blocked>
    %46 = arith.subi %43, %cst_7 : tensor<1024xi32, #blocked>
    %47 = arith.select %45, %46, %43 : tensor<1024xi1, #blocked>, tensor<1024xi32, #blocked>
    %48 = arith.cmpi slt, %42, %cst_9 : tensor<1024xi32, #blocked>
    %49 = arith.select %48, %47, %43 : tensor<1024xi1, #blocked>, tensor<1024xi32, #blocked>
    %50 = arith.cmpi sgt, %49, %cst_9 : tensor<1024xi32, #blocked>
    %51 = arith.extui %50 : tensor<1024xi1, #blocked> to tensor<1024xi32, #blocked>
    %52 = arith.muli %49, %51 : tensor<1024xi32, #blocked>
    %53 = arith.divsi %9, %cst_8 : tensor<1024xi32, #blocked>
    %54 = arith.addi %53, %cst_7 : tensor<1024xi32, #blocked>
    %55 = arith.cmpi sge, %54, %cst_6 : tensor<1024xi32, #blocked>
    %56 = arith.extui %55 : tensor<1024xi1, #blocked> to tensor<1024xi32, #blocked>
    %57 = arith.muli %56, %cst_6 : tensor<1024xi32, #blocked>
    %58 = arith.cmpi slt, %54, %cst_6 : tensor<1024xi32, #blocked>
    %59 = arith.extui %58 : tensor<1024xi1, #blocked> to tensor<1024xi32, #blocked>
    %60 = arith.muli %54, %59 : tensor<1024xi32, #blocked>
    %61 = arith.addi %57, %60 : tensor<1024xi32, #blocked>
    %62 = arith.addi %61, %cst_10 : tensor<1024xi32, #blocked>
    %63 = arith.cmpi sle, %52, %62 : tensor<1024xi32, #blocked>
    %64 = arith.extui %63 : tensor<1024xi1, #blocked> to tensor<1024xi32, #blocked>
    %65 = arith.muli %52, %64 : tensor<1024xi32, #blocked>
    %66 = arith.cmpi slt, %62, %52 : tensor<1024xi32, #blocked>
    %67 = arith.extui %66 : tensor<1024xi1, #blocked> to tensor<1024xi32, #blocked>
    %68 = arith.muli %62, %67 : tensor<1024xi32, #blocked>
    %69 = arith.addi %65, %68 : tensor<1024xi32, #blocked>
    %70 = arith.muli %69, %cst_5 : tensor<1024xi32, #blocked>
    %71 = arith.addi %41, %70 : tensor<1024xi32, #blocked>
    %72 = arith.muli %10, %cst_4 : tensor<1024xi32, #blocked>
    %73 = arith.addi %71, %72 : tensor<1024xi32, #blocked>
    %74 = tt.splat %in_ptr0 : !tt.ptr<i8> -> tensor<1024x!tt.ptr<i8>, #blocked>
    %75 = tt.addptr %74, %73 : tensor<1024x!tt.ptr<i8>, #blocked>, tensor<1024xi32, #blocked>
    %76 = ttg.convert_layout %75 : tensor<1024x!tt.ptr<i8>, #blocked> -> tensor<1024x!tt.ptr<i8>, #blocked1>
    %77 = tt.load %76 : tensor<1024x!tt.ptr<i8>, #blocked1>
    %78 = ttg.convert_layout %77 : tensor<1024xi8, #blocked1> -> tensor<1024xi8, #blocked>
    %79 = tt.splat %in_ptr1 : !tt.ptr<f16> -> tensor<1024x!tt.ptr<f16>, #blocked>
    %80 = tt.addptr %79, %73 : tensor<1024x!tt.ptr<f16>, #blocked>, tensor<1024xi32, #blocked>
    %81 = ttg.convert_layout %80 : tensor<1024x!tt.ptr<f16>, #blocked> -> tensor<1024x!tt.ptr<f16>, #blocked1>
    %82 = tt.load %81 : tensor<1024x!tt.ptr<f16>, #blocked1>
    %83 = ttg.convert_layout %82 : tensor<1024xf16, #blocked1> -> tensor<1024xf16, #blocked>
    %84 = arith.extf %83 : tensor<1024xf16, #blocked> to tensor<1024xf32, #blocked>
    %85 = arith.addi %22, %cst_7 : tensor<1024xi32, #blocked>
    %86 = arith.cmpi sle, %85, %32 : tensor<1024xi32, #blocked>
    %87 = arith.extui %86 : tensor<1024xi1, #blocked> to tensor<1024xi32, #blocked>
    %88 = arith.muli %85, %87 : tensor<1024xi32, #blocked>
    %89 = arith.cmpi slt, %32, %85 : tensor<1024xi32, #blocked>
    %90 = arith.extui %89 : tensor<1024xi1, #blocked> to tensor<1024xi32, #blocked>
    %91 = arith.muli %32, %90 : tensor<1024xi32, #blocked>
    %92 = arith.addi %88, %91 : tensor<1024xi32, #blocked>
    %93 = arith.muli %92, %cst_15 : tensor<1024xi32, #blocked>
    %94 = arith.addi %5, %93 : tensor<1024xi32, #blocked>
    %95 = arith.addi %94, %70 : tensor<1024xi32, #blocked>
    %96 = arith.addi %95, %72 : tensor<1024xi32, #blocked>
    %97 = tt.addptr %74, %96 : tensor<1024x!tt.ptr<i8>, #blocked>, tensor<1024xi32, #blocked>
    %98 = ttg.convert_layout %97 : tensor<1024x!tt.ptr<i8>, #blocked> -> tensor<1024x!tt.ptr<i8>, #blocked1>
    %99 = tt.load %98 : tensor<1024x!tt.ptr<i8>, #blocked1>
    %100 = ttg.convert_layout %99 : tensor<1024xi8, #blocked1> -> tensor<1024xi8, #blocked>
    %101 = tt.addptr %79, %96 : tensor<1024x!tt.ptr<f16>, #blocked>, tensor<1024xi32, #blocked>
    %102 = ttg.convert_layout %101 : tensor<1024x!tt.ptr<f16>, #blocked> -> tensor<1024x!tt.ptr<f16>, #blocked1>
    %103 = tt.load %102 : tensor<1024x!tt.ptr<f16>, #blocked1>
    %104 = ttg.convert_layout %103 : tensor<1024xf16, #blocked1> -> tensor<1024xf16, #blocked>
    %105 = arith.extf %104 : tensor<1024xf16, #blocked> to tensor<1024xf32, #blocked>
    %106 = arith.addi %52, %cst_7 : tensor<1024xi32, #blocked>
    %107 = arith.cmpi sle, %106, %62 : tensor<1024xi32, #blocked>
    %108 = arith.extui %107 : tensor<1024xi1, #blocked> to tensor<1024xi32, #blocked>
    %109 = arith.muli %106, %108 : tensor<1024xi32, #blocked>
    %110 = arith.cmpi slt, %62, %106 : tensor<1024xi32, #blocked>
    %111 = arith.extui %110 : tensor<1024xi1, #blocked> to tensor<1024xi32, #blocked>
    %112 = arith.muli %62, %111 : tensor<1024xi32, #blocked>
    %113 = arith.addi %109, %112 : tensor<1024xi32, #blocked>
    %114 = arith.muli %113, %cst_5 : tensor<1024xi32, #blocked>
    %115 = arith.addi %41, %114 : tensor<1024xi32, #blocked>
    %116 = arith.addi %115, %72 : tensor<1024xi32, #blocked>
    %117 = tt.addptr %74, %116 : tensor<1024x!tt.ptr<i8>, #blocked>, tensor<1024xi32, #blocked>
    %118 = ttg.convert_layout %117 : tensor<1024x!tt.ptr<i8>, #blocked> -> tensor<1024x!tt.ptr<i8>, #blocked1>
    %119 = tt.load %118 : tensor<1024x!tt.ptr<i8>, #blocked1>
    %120 = ttg.convert_layout %119 : tensor<1024xi8, #blocked1> -> tensor<1024xi8, #blocked>
    %121 = tt.addptr %79, %116 : tensor<1024x!tt.ptr<f16>, #blocked>, tensor<1024xi32, #blocked>
    %122 = ttg.convert_layout %121 : tensor<1024x!tt.ptr<f16>, #blocked> -> tensor<1024x!tt.ptr<f16>, #blocked1>
    %123 = tt.load %122 : tensor<1024x!tt.ptr<f16>, #blocked1>
    %124 = ttg.convert_layout %123 : tensor<1024xf16, #blocked1> -> tensor<1024xf16, #blocked>
    %125 = arith.extf %124 : tensor<1024xf16, #blocked> to tensor<1024xf32, #blocked>
    %126 = arith.addi %94, %114 : tensor<1024xi32, #blocked>
    %127 = arith.addi %126, %72 : tensor<1024xi32, #blocked>
    %128 = tt.addptr %74, %127 : tensor<1024x!tt.ptr<i8>, #blocked>, tensor<1024xi32, #blocked>
    %129 = ttg.convert_layout %128 : tensor<1024x!tt.ptr<i8>, #blocked> -> tensor<1024x!tt.ptr<i8>, #blocked1>
    %130 = tt.load %129 : tensor<1024x!tt.ptr<i8>, #blocked1>
    %131 = ttg.convert_layout %130 : tensor<1024xi8, #blocked1> -> tensor<1024xi8, #blocked>
    %132 = tt.addptr %79, %127 : tensor<1024x!tt.ptr<f16>, #blocked>, tensor<1024xi32, #blocked>
    %133 = ttg.convert_layout %132 : tensor<1024x!tt.ptr<f16>, #blocked> -> tensor<1024x!tt.ptr<f16>, #blocked1>
    %134 = tt.load %133 : tensor<1024x!tt.ptr<f16>, #blocked1>
    %135 = ttg.convert_layout %134 : tensor<1024xf16, #blocked1> -> tensor<1024xf16, #blocked>
    %136 = arith.extf %135 : tensor<1024xf16, #blocked> to tensor<1024xf32, #blocked>
    %137 = arith.extsi %78 : tensor<1024xi8, #blocked> to tensor<1024xi32, #blocked>
    %138 = arith.addi %137, %cst_3 : tensor<1024xi32, #blocked>
    %139 = arith.cmpi slt, %137, %cst_9 : tensor<1024xi32, #blocked>
    %140 = arith.select %139, %138, %137 : tensor<1024xi1, #blocked>, tensor<1024xi32, #blocked>
    %141 = arith.cmpi sge, %140, %cst_9 : tensor<1024xi32, #blocked>
    %142 = arith.cmpi slt, %140, %cst_3 : tensor<1024xi32, #blocked>
    %143 = arith.andi %141, %142 : tensor<1024xi1, #blocked>
    tt.assert %143, "index out of bounds: 0 <= tmp4 < 9" : tensor<1024xi1, #blocked>
    %144 = arith.muli %39, %cst_8 : tensor<1024xi32, #blocked>
    %145 = arith.addi %140, %144 : tensor<1024xi32, #blocked>
    %146 = arith.divsi %140, %cst_2 : tensor<1024xi32, #blocked>
    %147 = arith.muli %146, %cst_1 : tensor<1024xi32, #blocked>
    %148 = arith.addi %145, %147 : tensor<1024xi32, #blocked>
    %149 = arith.muli %69, %cst_0 : tensor<1024xi32, #blocked>
    %150 = arith.addi %148, %149 : tensor<1024xi32, #blocked>
    %151 = arith.cmpi eq, %150, %11 : tensor<1024xi32, #blocked>
    %152 = arith.select %151, %84, %cst : tensor<1024xi1, #blocked>, tensor<1024xf32, #blocked>
    %153 = arith.extsi %100 : tensor<1024xi8, #blocked> to tensor<1024xi32, #blocked>
    %154 = arith.addi %153, %cst_3 : tensor<1024xi32, #blocked>
    %155 = arith.cmpi slt, %153, %cst_9 : tensor<1024xi32, #blocked>
    %156 = arith.select %155, %154, %153 : tensor<1024xi1, #blocked>, tensor<1024xi32, #blocked>
    %157 = arith.cmpi sge, %156, %cst_9 : tensor<1024xi32, #blocked>
    %158 = arith.cmpi slt, %156, %cst_3 : tensor<1024xi32, #blocked>
    %159 = arith.andi %157, %158 : tensor<1024xi1, #blocked>
    tt.assert %159, "index out of bounds: 0 <= tmp15 < 9" : tensor<1024xi1, #blocked>
    %160 = arith.muli %92, %cst_8 : tensor<1024xi32, #blocked>
    %161 = arith.addi %156, %160 : tensor<1024xi32, #blocked>
    %162 = arith.divsi %156, %cst_2 : tensor<1024xi32, #blocked>
    %163 = arith.muli %162, %cst_1 : tensor<1024xi32, #blocked>
    %164 = arith.addi %161, %163 : tensor<1024xi32, #blocked>
    %165 = arith.addi %164, %149 : tensor<1024xi32, #blocked>
    %166 = arith.cmpi eq, %165, %11 : tensor<1024xi32, #blocked>
    %167 = arith.cmpi slt, %52, %61 : tensor<1024xi32, #blocked>
    %168 = arith.cmpi slt, %85, %31 : tensor<1024xi32, #blocked>
    %169 = arith.andi %167, %168 : tensor<1024xi1, #blocked>
    %170 = arith.andi %169, %166 : tensor<1024xi1, #blocked>
    %171 = arith.addf %152, %105 : tensor<1024xf32, #blocked>
    %172 = arith.select %170, %171, %152 : tensor<1024xi1, #blocked>, tensor<1024xf32, #blocked>
    %173 = arith.extsi %120 : tensor<1024xi8, #blocked> to tensor<1024xi32, #blocked>
    %174 = arith.addi %173, %cst_3 : tensor<1024xi32, #blocked>
    %175 = arith.cmpi slt, %173, %cst_9 : tensor<1024xi32, #blocked>
    %176 = arith.select %175, %174, %173 : tensor<1024xi1, #blocked>, tensor<1024xi32, #blocked>
    %177 = arith.cmpi sge, %176, %cst_9 : tensor<1024xi32, #blocked>
    %178 = arith.cmpi slt, %176, %cst_3 : tensor<1024xi32, #blocked>
    %179 = arith.andi %177, %178 : tensor<1024xi1, #blocked>
    tt.assert %179, "index out of bounds: 0 <= tmp33 < 9" : tensor<1024xi1, #blocked>
    %180 = arith.addi %176, %144 : tensor<1024xi32, #blocked>
    %181 = arith.divsi %176, %cst_2 : tensor<1024xi32, #blocked>
    %182 = arith.muli %181, %cst_1 : tensor<1024xi32, #blocked>
    %183 = arith.addi %180, %182 : tensor<1024xi32, #blocked>
    %184 = arith.muli %113, %cst_0 : tensor<1024xi32, #blocked>
    %185 = arith.addi %183, %184 : tensor<1024xi32, #blocked>
    %186 = arith.cmpi eq, %185, %11 : tensor<1024xi32, #blocked>
    %187 = arith.cmpi slt, %106, %61 : tensor<1024xi32, #blocked>
    %188 = arith.cmpi slt, %22, %31 : tensor<1024xi32, #blocked>
    %189 = arith.andi %187, %188 : tensor<1024xi1, #blocked>
    %190 = arith.andi %189, %186 : tensor<1024xi1, #blocked>
    %191 = arith.addf %172, %125 : tensor<1024xf32, #blocked>
    %192 = arith.select %190, %191, %172 : tensor<1024xi1, #blocked>, tensor<1024xf32, #blocked>
    %193 = arith.extsi %131 : tensor<1024xi8, #blocked> to tensor<1024xi32, #blocked>
    %194 = arith.addi %193, %cst_3 : tensor<1024xi32, #blocked>
    %195 = arith.cmpi slt, %193, %cst_9 : tensor<1024xi32, #blocked>
    %196 = arith.select %195, %194, %193 : tensor<1024xi1, #blocked>, tensor<1024xi32, #blocked>
    %197 = arith.cmpi sge, %196, %cst_9 : tensor<1024xi32, #blocked>
    %198 = arith.cmpi slt, %196, %cst_3 : tensor<1024xi32, #blocked>
    %199 = arith.andi %197, %198 : tensor<1024xi1, #blocked>
    tt.assert %199, "index out of bounds: 0 <= tmp49 < 9" : tensor<1024xi1, #blocked>
    %200 = arith.addi %196, %160 : tensor<1024xi32, #blocked>
    %201 = arith.divsi %196, %cst_2 : tensor<1024xi32, #blocked>
    %202 = arith.muli %201, %cst_1 : tensor<1024xi32, #blocked>
    %203 = arith.addi %200, %202 : tensor<1024xi32, #blocked>
    %204 = arith.addi %203, %184 : tensor<1024xi32, #blocked>
    %205 = arith.cmpi eq, %204, %11 : tensor<1024xi32, #blocked>
    %206 = arith.andi %187, %168 : tensor<1024xi1, #blocked>
    %207 = arith.andi %206, %205 : tensor<1024xi1, #blocked>
    %208 = arith.addf %192, %136 : tensor<1024xf32, #blocked>
    %209 = arith.select %207, %208, %192 : tensor<1024xi1, #blocked>, tensor<1024xf32, #blocked>
    %210 = tt.splat %out_ptr0 : !tt.ptr<f16> -> tensor<1024x!tt.ptr<f16>, #blocked>
    %211 = tt.addptr %210, %4 : tensor<1024x!tt.ptr<f16>, #blocked>, tensor<1024xi32, #blocked>
    %212 = arith.truncf %209 : tensor<1024xf32, #blocked> to tensor<1024xf16, #blocked>
    %213 = ttg.convert_layout %211 : tensor<1024x!tt.ptr<f16>, #blocked> -> tensor<1024x!tt.ptr<f16>, #blocked1>
    %214 = ttg.convert_layout %212 : tensor<1024xf16, #blocked> -> tensor<1024xf16, #blocked1>
    tt.store %213, %214 : tensor<1024x!tt.ptr<f16>, #blocked1>
    tt.return
  }
}

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1560,7 +1560,9 @@ void LayoutRematerialization::backwardRematerialization(
15601560
// We measure costs in standardised milli-SM-cycles. The smem load
15611561
// and store each cost 8 * convertLayoutBytes, and then we double
15621562
// it to account for extra cost due to synchronisation.
1563-
int64_t convertLayoutCost = 32 * convertLayoutBytes;
1563+
// FIXME: measure cost of smem load/store and synchronisation on Intel GPUs,
1564+
// and refine this model further. (#5476)
1565+
int64_t convertLayoutCost = 32 * convertLayoutBytes * 2;
15641566
int64_t rematerialisationCost = 0;
15651567

15661568
// Evaluate single-use status for every operation in slice

0 commit comments

Comments
 (0)