diff --git a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
index 86b45466..3f4cb514 100644
--- a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
+++ b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
@@ -706,13 +706,12 @@ static SmallVector<Value> createNdDescriptorTiles(
     Value newRowOffs = rewriter.create<arith::ConstantIndexOp>(loc, i);
     for (int j = 0; j < loadShape[1]; j += descTile[1] * arrayLength) {
       Value newColOffs = rewriter.create<arith::ConstantIndexOp>(loc, j);
-      if (transpose) {
-        std::swap(newRowOffs, newColOffs);
-      }
       auto tile = rewriter
                       .create<xegpu::UpdateNdOffsetOp>(
                           loc, descType, rootTile,
-                          /*offsets=*/ValueRange{newRowOffs, newColOffs},
+                          /*offsets=*/
+                          transpose ? ValueRange{newColOffs, newRowOffs}
+                                    : ValueRange{newRowOffs, newColOffs},
                           SmallVector<int64_t>{ShapedType::kDynamic,
                                                ShapedType::kDynamic})
                       .getResult();
diff --git a/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas-transpose.mlir b/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas-transpose.mlir
index be38afae..fa7ce025 100644
--- a/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas-transpose.mlir
+++ b/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-dpas-transpose.mlir
@@ -3,13 +3,14 @@ module {
   func.func @matmul_transpose_b(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf16>) {
     %c0 = arith.constant 0 : index
-    %c32 = arith.constant 32 : index
+    %c16 = arith.constant 16 : index
+    %c64 = arith.constant 64 : index
     %c1024 = arith.constant 1024 : index
-    scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c1024, %c1024) step (%c32, %c32) {
-      %subview_0 = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<1024x1024xf16> to memref<32x32xf16, strided<[1024, 1], offset: ?>>
-      %subview_1 = memref.subview %arg0[%arg3, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>>
-      %subview_2 = memref.subview %arg1[%arg4, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>>
-      linalg.matmul_transpose_b ins(%subview_1, %subview_2 : memref<32x1024xf16, strided<[1024, 1], offset: ?>>, memref<32x1024xf16, strided<[1024, 1], offset: ?>>) outs(%subview_0 : memref<32x32xf16, strided<[1024, 1], offset: ?>>)
+    scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c1024, %c1024) step (%c16, %c64) {
+      %subview_0 = memref.subview %arg2[%arg3, %arg4] [16, 64] [1, 1] : memref<1024x1024xf16> to memref<16x64xf16, strided<[1024, 1], offset: ?>>
+      %subview_1 = memref.subview %arg0[%arg3, 0] [16, 1024] [1, 1] : memref<1024x1024xf16> to memref<16x1024xf16, strided<[1024, 1], offset: ?>>
+      %subview_2 = memref.subview %arg1[%arg4, 0] [64, 1024] [1, 1] : memref<1024x1024xf16> to memref<64x1024xf16, strided<[1024, 1], offset: ?>>
+      linalg.matmul_transpose_b ins(%subview_1, %subview_2 : memref<16x1024xf16, strided<[1024, 1], offset: ?>>, memref<64x1024xf16, strided<[1024, 1], offset: ?>>) outs(%subview_0 : memref<16x64xf16, strided<[1024, 1], offset: ?>>)
       scf.reduce
     }
     return
@@ -19,7 +20,7 @@ module {
 // CHECK-LABEL: func.func @matmul_transpose_b
 // CHECK-SAME: %[[Ap:.+]]: memref<1024x1024xf16>, %[[Bp:.+]]: memref<1024x1024xf16>, %[[Cp:.+]]: memref<1024x1024xf16>
-// CHECK: scf.parallel (%[[iter1:.+]], %[[iter2:.+]]) = (%c0, %c0) to (%c1024, %c1024) step (%c32, %c32) {
+// CHECK: scf.parallel (%[[iter1:.+]], %[[iter2:.+]]) = (%c0, %c0) to (%c1024, %c1024) step (%c16, %c64) {
 // CHECK: %[[C:.+]] = memref.subview %[[Cp]][%[[iter1]], %[[iter2]]] {{.*}}
 // CHECK: %[[A:.+]] = memref.subview %[[Ap]][%[[iter1]], 0] {{.*}}
 // CHECK: %[[B:.+]] = memref.subview %[[Bp]][%[[iter2]], 0] {{.*}}
@@ -43,9 +44,11 @@ module {
 // CHECK: %[[rootB:.+]] = xegpu.create_nd_tdesc %[[B]]
 // CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [%c0, %c0]
 // CHECK: %[[tB1:.+]] = xegpu.update_nd_offset %[[rootB]], [%c16, %c0]
+// CHECK: %[[tB2:.+]] = xegpu.update_nd_offset %[[rootB]], [%c32, %c0]
+// CHECK: %[[tB3:.+]] = xegpu.update_nd_offset %[[rootB]], [%c48, %c0]

 // Create DPAS computation loop over tiled reduction dimension.
-// CHECK: %[[res:.+]]:11 = scf.for{{.*}}%c0 to %c1024 step %c16
+// CHECK: %[[res:.+]]:13 = scf.for{{.*}}%c0 to %c1024 step %c16
 // CHECK-SAME: iter_args(%[[acc:.+]] = %[[vC_f32]],{{.*}}%[[iterA:.+]] = %[[tA]],{{.*}}%[[iterB:.+]] = %[[tB]],{{.*}}%[[iterB1:.+]] = %[[tB1]]
 // CHECK-SAME: {
@@ -66,10 +69,10 @@ module {

 // Extract DPAS-sized chunks from larger loaded tile A.
 // Tile B is already in the correct shape.
-// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16>
-// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16>
+// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<16x16xf16> to vector<256xf16>
+// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<256xf16> to vector<128xf16>
 // CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x16xf16>
-// CHECK-COUNT-3: vector.extract_strided_slice
+// CHECK-COUNT-1: vector.extract_strided_slice

 // Perform DPAS computation.
 // CHECK: %[[dpas:.+]] = xegpu.dpas %[[vA_dpas]], %[[vB]], %[[acc]]
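
The C++ hunk above replaces an in-loop std::swap with a per-tile select of the offset order. The swap mutated newRowOffs, which is declared outside the inner loop, so once the transpose branch ran, the row offset carried a stale value into every later column tile of the same row. Below is a minimal standalone sketch of that failure mode using plain integers instead of MLIR Values; the tile sizes and function names are illustrative assumptions, not code from the pass.

#include <cstdio>
#include <utility>
#include <vector>

struct Offsets {
  int row, col;
};

// Old logic: std::swap mutates rowOffs, which lives outside the inner
// loop. After the first swap, rowOffs no longer holds the row offset,
// so later column tiles in the same row receive stale values.
std::vector<Offsets> swapBasedTiles(int rows, int cols, int tileRows,
                                    int tileCols, bool transpose) {
  std::vector<Offsets> tiles;
  for (int i = 0; i < rows; i += tileRows) {
    int rowOffs = i;
    for (int j = 0; j < cols; j += tileCols) {
      int colOffs = j;
      if (transpose)
        std::swap(rowOffs, colOffs); // corrupts rowOffs for the next j
      tiles.push_back({rowOffs, colOffs});
    }
  }
  return tiles;
}

// Fixed logic, mirroring the ternary in the patch: pick the offset
// order per tile without mutating any loop-carried state.
std::vector<Offsets> selectBasedTiles(int rows, int cols, int tileRows,
                                      int tileCols, bool transpose) {
  std::vector<Offsets> tiles;
  for (int i = 0; i < rows; i += tileRows)
    for (int j = 0; j < cols; j += tileCols)
      tiles.push_back(transpose ? Offsets{j, i} : Offsets{i, j});
  return tiles;
}

int main() {
  // Two rows and two columns of 16x16 tiles, transposed descriptors.
  auto buggy = swapBasedTiles(32, 32, 16, 16, /*transpose=*/true);
  auto fixed = selectBasedTiles(32, 32, 16, 16, /*transpose=*/true);
  for (size_t k = 0; k < fixed.size(); ++k)
    std::printf("tile %zu: swap-based (%d, %d) vs fixed (%d, %d)\n", k,
                buggy[k].row, buggy[k].col, fixed[k].row, fixed[k].col);
  // Tiles 0-2 happen to agree; tile 3 diverges: swap-based yields
  // (16, 0) instead of (16, 16), because rowOffs still holds the value
  // swapped in from the previous iteration's colOffs.
  return 0;
}

The enlarged test tile (16x64 instead of 32x32) exercises exactly this case: B now splits into four 16-wide descriptor tiles (tB through tB3), so any cross-iteration offset corruption shows up in the update_nd_offset checks, and the scf.for now carries 13 iter_args (8 accumulator tiles for the 16x64 C block, 1 A tile, 4 B tiles) instead of 11.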