@@ -57,36 +57,31 @@ gpu.module @test_round_robin_assignment {
5757 }
5858
// Review notes for the `dpas` case. This text is a diff rendering: lines whose
// gutter ends in `-` are the removed version, lines ending in `+` the added one.
// The added version switches both operands from f32 to f16, drops the third
// memref argument together with its tensor descriptor, and declares
// lane_layout = [1, 16] (with lane_data = [2, 1] for %b) in the input IR;
// the dpas result type stays vector<256x256xf32>.
// NOTE(review): `CHECK-SAME-COUNT-<N>` does not appear among FileCheck's
// documented directives, so those lines are presumably skipped rather than
// matched -- confirm. That would explain why the added tensor_desc
// expectations still say lane_layout = [8, 4] / [4, 8] while the added IR
// declares [1, 16]; verify which layout the pass actually prints.
5959 // CHECK-LABEL: dpas
60- // CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf32 >, %[[ARG_1:.*]]: memref<128x256xf32>, %[[ARG_2:.*]]: memref<256x256xf32 >)
61- gpu.func @dpas (%a: memref <256 x 128 x f32 >, %b: memref <128 x 256 x f32 >, %c: memref < 256 x 256 x f32 >) {
62- // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32 >
63- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf32 , #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>>
60+ // CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf16 >, %[[ARG_1:.*]]: memref<128x256xf16 >)
61+ gpu.func @dpas (%a: memref <256 x 128 x f16 >, %b: memref <128 x 256 x f16 >) {
62+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf16 >
63+ // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16 , #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>>
6464 // CHECK-NOT: xegpu.create_nd_tdesc
65- // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<128x256xf32>
66- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [4, 8], lane_data = [1, 1]>>
67- // CHECK-NOT: xegpu.create_nd_tdesc
68- // CHECK-COUNT-4: xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<256x256xf32>
69- // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [8, 8], lane_data = [1, 1]>>
65+ // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<128x256xf16>
66+ // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [4, 8], lane_data = [1, 1]>>
7067 // CHECK-NOT: xegpu.create_nd_tdesc
7168 // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
72- // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [8, 8 ], lane_data = [1, 1]>}
73- // CHECK-SAME-COUNT-16: : vector<16x16xf32 >, vector<16x16xf32 > -> vector<16x16xf32>
69+ // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [1, 16 ], lane_data = [1, 1]>}
70+ // CHECK-SAME-COUNT-16: : vector<16x16xf16 >, vector<16x16xf16 > -> vector<16x16xf32>
7471 // CHECK-NOT: xegpu.dpas
75- %tdesc_a = xegpu.create_nd_tdesc %a [0 , 0 ] : memref <256 x 128 x f32 >
76- -> !xegpu.tensor_desc <256 x 128 x f32 , #xegpu.layout <sg_layout = [8 , 4 ], sg_data = [16 , 16 ], lane_layout = [8 , 4 ], lane_data = [1 , 1 ]>>
72+ %tdesc_a = xegpu.create_nd_tdesc %a [0 , 0 ] : memref <256 x 128 x f16 >
73+ -> !xegpu.tensor_desc <256 x 128 x f16 , #xegpu.layout <sg_layout = [8 , 4 ], sg_data = [16 , 16 ], lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>>
7774 %load_a = xegpu.load_nd %tdesc_a
78- : !xegpu.tensor_desc <256 x 128 x f32 , #xegpu.layout <sg_layout = [8 , 4 ], sg_data = [16 , 16 ], lane_layout = [8 , 4 ], lane_data = [1 , 1 ]>>
79- -> vector <256 x 128 x f32 >
80- %tdesc_b = xegpu.create_nd_tdesc %b [0 , 0 ] : memref <128 x 256 x f32 >
81- -> !xegpu.tensor_desc <128 x 256 x f32 , #xegpu.layout <sg_layout = [4 , 8 ], sg_data = [16 , 16 ], lane_layout = [4 , 8 ], lane_data = [1 , 1 ]>>
75+ : !xegpu.tensor_desc <256 x 128 x f16 , #xegpu.layout <sg_layout = [8 , 4 ], sg_data = [16 , 16 ], lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>>
76+ -> vector <256 x 128 x f16 >
77+ %tdesc_b = xegpu.create_nd_tdesc %b [0 , 0 ] : memref <128 x 256 x f16 >
78+ -> !xegpu.tensor_desc <128 x 256 x f16 , #xegpu.layout <sg_layout = [4 , 8 ], sg_data = [16 , 16 ], lane_layout = [1 , 16 ], lane_data = [2 , 1 ]>>
8279 %load_b = xegpu.load_nd %tdesc_b
83- : !xegpu.tensor_desc <128 x256 xf32 , #xegpu.layout <sg_layout = [4 , 8 ], sg_data = [16 , 16 ], lane_layout = [4 , 8 ], lane_data = [1 , 1 ]>>
84- -> vector <128 x256 xf32 >
85- %tdesc_c = xegpu.create_nd_tdesc %c [0 , 0 ] : memref <256 x256 xf32 >
86- -> !xegpu.tensor_desc <256 x256 xf32 , #xegpu.layout <sg_layout = [8 , 8 ], sg_data = [16 , 16 ], lane_layout = [8 , 8 ], lane_data = [1 , 1 ]>>
80+ : !xegpu.tensor_desc <128 x256 xf16 , #xegpu.layout <sg_layout = [4 , 8 ], sg_data = [16 , 16 ], lane_layout = [1 , 16 ], lane_data = [2 , 1 ]>>
81+ -> vector <128 x256 xf16 >
8782 %dpas = xegpu.dpas %load_a , %load_b
88- {layout_result_0 = #xegpu.layout <sg_layout = [8 , 8 ], sg_data = [16 , 16 ], lane_layout = [8 , 8 ], lane_data = [1 , 1 ]>}
89- : vector <256 x 128 x f32 >, vector <128 x 256 x f32 > -> vector <256 x256 xf32 >
83+ {layout_result_0 = #xegpu.layout <sg_layout = [8 , 8 ], sg_data = [16 , 16 ], lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>}
84+ : vector <256 x 128 x f16 >, vector <128 x 256 x f16 > -> vector <256 x256 xf32 >
9085 gpu.return
9186 }
9287
0 commit comments