@@ -147,4 +147,53 @@ gpu.module @test_round_robin_assignment {
147147 }
148148 gpu.return
149149 }
150+
151+ gpu.func @test_scf_if (%arg0: memref <1024 xf32 >, %arg1: memref <1024 xf32 >) {
152+ %c10 = arith.constant 10 : index
153+ %0 = gpu.subgroup_id : index
154+ %1 = xegpu.create_nd_tdesc %arg0 [0 ] : memref <1024 xf32 > -> !xegpu.tensor_desc <256 xf32 , #xegpu.layout <sg_layout = [8 ], sg_data = [16 ]>>
155+ %2 = xegpu.create_nd_tdesc %arg1 [0 ] : memref <1024 xf32 > -> !xegpu.tensor_desc <256 xf32 , #xegpu.layout <sg_layout = [8 ], sg_data = [16 ]>>
156+ %3 = arith.cmpi eq , %0 , %c10 : index
157+ // CHECK-LABEL: scf.if
158+ // CHECK-SAME: (vector<16xf32>, vector<16xf32>)
159+ %4 = scf.if %3 -> (vector <256 xf32 >) {
160+ %5 = xegpu.load_nd %1 : !xegpu.tensor_desc <256 xf32 , #xegpu.layout <sg_layout = [8 ], sg_data = [16 ]>> -> vector <256 xf32 >
161+ // CHECK-LABEL: scf.yield
162+ // CHECK-SAME: vector<16xf32>, vector<16xf32>
163+ scf.yield %5 : vector <256 xf32 >
164+ } else {
165+ %5 = xegpu.load_nd %2 : !xegpu.tensor_desc <256 xf32 , #xegpu.layout <sg_layout = [8 ], sg_data = [16 ]>> -> vector <256 xf32 >
166+ // CHECK-LABEL: scf.yield
167+ // CHECK-SAME: vector<16xf32>, vector<16xf32>
168+ scf.yield %5 : vector <256 xf32 >
169+ } {layout_result_0 = #xegpu.layout <sg_layout = [8 ], sg_data = [16 ]>}
170+ xegpu.store_nd %4 , %1 : vector <256 xf32 >, !xegpu.tensor_desc <256 xf32 , #xegpu.layout <sg_layout = [8 ], sg_data = [16 ]>>
171+ gpu.return
172+ }
173+
174+ gpu.func @test_scf_if_tensor_desc (%arg0: memref <1024 xf32 >, %arg1: memref <1024 xf32 >) {
175+ %c10 = arith.constant 10 : index
176+ %id = gpu.subgroup_id : index
177+
178+ %t = xegpu.create_nd_tdesc %arg0 [0 ] : memref <1024 xf32 > -> !xegpu.tensor_desc <256 xf32 , #xegpu.layout <sg_layout = [8 ], sg_data = [16 ]>>
179+ %d = xegpu.load_nd %t : !xegpu.tensor_desc <256 xf32 , #xegpu.layout <sg_layout = [8 ], sg_data = [16 ]>> -> vector <256 xf32 >
180+
181+ %0 = arith.cmpi eq , %id , %c10 : index
182+ // CHECK-LABEL: scf.if
183+ // CHECK-SAME: (!xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>)
184+ %1 = scf.if %0 -> (!xegpu.tensor_desc <256 xf32 , #xegpu.layout <sg_layout = [8 ], sg_data = [16 ]>>) {
185+ %2 = xegpu.create_nd_tdesc %arg0 [0 ] : memref <1024 xf32 > -> !xegpu.tensor_desc <256 xf32 , #xegpu.layout <sg_layout = [8 ], sg_data = [16 ]>>
186+ // CHECK-LABEL: scf.yield
187+ // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
188+ scf.yield %2 : !xegpu.tensor_desc <256 xf32 , #xegpu.layout <sg_layout = [8 ], sg_data = [16 ]>>
189+ } else {
190+ %3 = xegpu.create_nd_tdesc %arg1 [0 ] : memref <1024 xf32 > -> !xegpu.tensor_desc <256 xf32 , #xegpu.layout <sg_layout = [8 ], sg_data = [16 ]>>
191+ // CHECK-LABEL: scf.yield
192+ // CHECK-SAME: !xegpu.tensor_desc<16xf32>, !xegpu.tensor_desc<16xf32>
193+ scf.yield %3 : !xegpu.tensor_desc <256 xf32 , #xegpu.layout <sg_layout = [8 ], sg_data = [16 ]>>
194+ }
195+ xegpu.store_nd %d , %1 : vector <256 xf32 >, !xegpu.tensor_desc <256 xf32 , #xegpu.layout <sg_layout = [8 ], sg_data = [16 ]>>
196+ gpu.return
197+ }
198+
150199}
0 commit comments