@@ -4,8 +4,9 @@ gpu.module @create_nd_tdesc {
44 // CHECK-LABEL: gpu.func @create_nd_tdesc
55 // CHECK-SAME: %[[ARG0:.*]]: memref<16x32xf32, 1>, %[[ARG1:.*]]: ui64,
66 // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
7+ // CHECK-SAME: %[[DYN:.*]]: memref<?x?xf16>) kernel {
78 gpu.func @create_nd_tdesc (%src: memref <16 x32 xf32 , 1 >, %ptr: ui64 , %shape1: index , %shape2: index ,
8- %stride1: index , %stride2: index , %offset1: index , %offset2: index ) kernel {
9+ %stride1: index , %stride2: index , %offset1: index , %offset2: index , %dyn: memref <?x?x f16 > ) kernel {
910 // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
1011 // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
1112 // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
@@ -43,6 +44,28 @@ gpu.module @create_nd_tdesc {
4344 // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
4445 // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
4546 %src_tdesc = xegpu.create_nd_tdesc %srcce : memref <16 x32 xf32 > -> !xegpu.tensor_desc <8 x16 xf32 >
47+
// Test case: build an nd tensor descriptor from the dynamically shaped
// memref argument (memref<?x?xf16>), passing shape and strides as explicit
// index operands instead of relying on a static memref type.
48+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
49+ %c1 = arith.constant 1 : index
50+ // CHECK: %[[C64:.*]] = arith.constant 64 : index
51+ %size_x = arith.constant 64 : index
52+ // CHECK: %[[C16:.*]] = arith.constant 16 : index
53+ %BLOCK_DMODEL = arith.constant 16 : index
// Expected lowering, as pinned by the directives below: the aligned base
// pointer of the memref is cast to i64 and written to element 0 of the
// payload viewed as vector<4xi64>; the payload is then viewed as
// vector<8xi32> again, and elements 2 through 5 receive the i32 casts of the
// constants 16 and 64 followed by two zero constants.
// NOTE(review): 16 is used as both shape[1] and stride[0] in this case, so
// the test cannot tell which operand feeds element 2 -- confirm against the
// lowering pattern.
54+ // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
55+ // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
56+ // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
57+ // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
58+ // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
59+ // CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
60+ // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
61+ // CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
62+ // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
63+ // CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
64+ // CHECK: %[[VAR27:.*]] = vector.insert %[[VAR21]], %[[VAR26]] [2] : i32 into vector<8xi32>
65+ // CHECK: %[[VAR28:.*]] = vector.insert %[[VAR22]], %[[VAR27]] [3] : i32 into vector<8xi32>
66+ // CHECK: %[[VAR29:.*]] = vector.insert %[[C0_I32_6]], %[[VAR28]] [4] : i32 into vector<8xi32>
67+ // CHECK: %[[VAR30:.*]] = vector.insert %[[C0_I32_7]], %[[VAR29]] [5] : i32 into vector<8xi32>
68+ %dyn_tdesc = xegpu.create_nd_tdesc %dyn , shape : [%size_x , %BLOCK_DMODEL ], strides : [%BLOCK_DMODEL , %c1 ] : memref <?x?xf16 > -> !xegpu.tensor_desc <16 x16 xf16 >
4669 gpu.return
4770 }
4871}
0 commit comments