@@ -4,45 +4,25 @@ gpu.module @create_nd_tdesc {
44 // CHECK-LABEL: gpu.func @create_nd_tdesc
55 // CHECK-SAME: %[[ARG0:.*]]: memref<16x32xf32, 1>, %[[ARG1:.*]]: ui64,
66 // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
7+ // CHECK-SAME: %[[ARG8:.*]]: memref<?x?xf16>) kernel {
78 gpu.func @create_nd_tdesc (%src: memref <16 x32 xf32 , 1 >, %ptr: ui64 , %shape1: index , %shape2: index ,
8- %stride1: index , %stride2: index , %offset1: index , %offset2: index ) kernel {
9- // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
10- // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
11- // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
12- // CHECK: %[[OFFSET_W:.*]] = arith.constant 0 : i32
13- // CHECK: %[[OFFSET_H:.*]] = arith.constant 0 : i32
14- // CHECK: %[[SHAPE_W:.*]] = arith.index_cast %[[ARG3]] : index to i32
15- // CHECK: %[[SHAPE_H:.*]] = arith.index_cast %[[ARG2]] : index to i32
16- // CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64>
17- // CHECK: %[[VAR7:.*]] = vector.insert %[[BASE_ADDR]], %[[VAR6]] [0] : i64 into vector<4xi64>
18- // CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32>
19- // CHECK: %[[VAR9:.*]] = vector.insert %[[SHAPE_W]], %[[VAR8]] [2] : i32 into vector<8xi32>
20- // CHECK: %[[VAR10:.*]] = vector.insert %[[SHAPE_H]], %[[VAR9]] [3] : i32 into vector<8xi32>
21- // CHECK: %[[VAR11:.*]] = vector.insert %[[OFFSET_W]], %[[VAR10]] [4] : i32 into vector<8xi32>
22- // CHECK: %[[VAR12:.*]] = vector.insert %[[OFFSET_H]], %[[VAR11]] [5] : i32 into vector<8xi32>
9+ %stride1: index , %stride2: index , %offset1: index , %offset2: index , %dyn: memref <?x?xf16 >) kernel {
10+ // No CHECK lines here: %ptr_tdesc is never used, so its lowering is expected to be optimized away.
2311 %ptr_tdesc = xegpu.create_nd_tdesc %ptr , shape :[%shape1 , %shape2 ], strides :[%stride1 , %stride2 ]
2412 : ui64 -> !xegpu.tensor_desc <8 x16 xf32 >
25-
26- // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
13+ // CHECK-NEXT: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
2714 %srcce = memref.memory_space_cast %src : memref <16 x32 xf32 , 1 > to memref <16 x32 xf32 >
28-
29- // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
30- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
31- // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
32- // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
33- // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
34- // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
35- // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
36- // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
37- // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
38- // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
39- // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
40- // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
41- // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
42- // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
43- // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
44- // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
15+ // No CHECK lines here: %src_tdesc is never used, so its lowering is expected to be optimized away.
4516 %src_tdesc = xegpu.create_nd_tdesc %srcce : memref <16 x32 xf32 > -> !xegpu.tensor_desc <8 x16 xf32 >
17+ // CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
18+ %c1 = arith.constant 1 : index
19+ // CHECK-NEXT: %[[C64:.*]] = arith.constant 64 : index
20+ %size_x = arith.constant 64 : index
21+ // CHECK-NEXT: %[[C16:.*]] = arith.constant 16 : index
22+ %BLOCK_DMODEL = arith.constant 16 : index
23+ // No CHECK lines here: %dyn_tdesc is never used, so its lowering is expected to be optimized away.
24+ %dyn_tdesc = xegpu.create_nd_tdesc %dyn , shape : [%size_x , %BLOCK_DMODEL ], strides : [%BLOCK_DMODEL , %c1 ] : memref <?x?xf16 > -> !xegpu.tensor_desc <16 x16 xf16 >
25+ // CHECK-NEXT: gpu.return
4626 gpu.return
4727 }
4828}
0 commit comments