lib/gc/Transforms/GPU/LinalgToXeGPU.cpp (3 additions & 4 deletions)
@@ -706,13 +706,12 @@ static SmallVector<Value> createNdDescriptorTiles(
     Value newRowOffs = rewriter.create<arith::ConstantIndexOp>(loc, i);
     for (int j = 0; j < loadShape[1]; j += descTile[1] * arrayLength) {
       Value newColOffs = rewriter.create<arith::ConstantIndexOp>(loc, j);
-      if (transpose) {
-        std::swap(newRowOffs, newColOffs);
-      }
Comment on lines -709 to -711 (Contributor Author):
Placing std::swap inside a nested for-loop was a bad idea, since it swaps the values on each iteration, producing nonsensical offsets at the end.
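The failure mode is easy to reproduce outside MLIR. Below is a minimal standalone sketch, with plain ints standing in for mlir::Value handles and made-up offsets: the swap mutates the outer-loop row offset in place, so the stale value leaks into the next inner iteration, while the select-based fix leaves both operands untouched.

// Illustrative sketch only: plain ints stand in for mlir::Value handles,
// and the offset values are invented for demonstration.
#include <algorithm>
#include <cstdio>

int main() {
  // Buggy variant: std::swap inside the inner loop mutates `row`,
  // so every other iteration sees the operands flipped back.
  int row = 16; // row offset, computed once per outer-loop iteration
  for (int j = 0; j < 48; j += 16) {
    int col = j; // column offset, computed per inner-loop iteration
    std::swap(row, col);
    std::printf("buggy: [%d, %d]\n", row, col); // [0,16] [16,0] [32,16]
  }

  // Fixed variant: pick the operand order without mutating anything,
  // mirroring the ternary over ValueRange in the patch.
  const bool transpose = true;
  row = 16;
  for (int j = 0; j < 48; j += 16) {
    int col = j;
    int r = transpose ? col : row;
    int c = transpose ? row : col;
    std::printf("fixed: [%d, %d]\n", r, c); // [0,16] [16,16] [32,16]
  }
  return 0;
}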

       auto tile = rewriter
                       .create<xegpu::UpdateNdOffsetOp>(
                           loc, descType, rootTile,
-                          /*offsets=*/ValueRange{newRowOffs, newColOffs},
+                          /*offsets=*/
+                          transpose ? ValueRange{newColOffs, newRowOffs}
+                                    : ValueRange{newRowOffs, newColOffs},
                           SmallVector<int64_t>{ShapedType::kDynamic,
                                                ShapedType::kDynamic})
                       .getResult();
@@ -3,13 +3,14 @@
 module {
   func.func @matmul_transpose_b(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf16>) {
     %c0 = arith.constant 0 : index
-    %c32 = arith.constant 32 : index
+    %c16 = arith.constant 16 : index
+    %c64 = arith.constant 64 : index
     %c1024 = arith.constant 1024 : index
-    scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c1024, %c1024) step (%c32, %c32) {
-      %subview_0 = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<1024x1024xf16> to memref<32x32xf16, strided<[1024, 1], offset: ?>>
-      %subview_1 = memref.subview %arg0[%arg3, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>>
-      %subview_2 = memref.subview %arg1[%arg4, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>>
-      linalg.matmul_transpose_b ins(%subview_1, %subview_2 : memref<32x1024xf16, strided<[1024, 1], offset: ?>>, memref<32x1024xf16, strided<[1024, 1], offset: ?>>) outs(%subview_0 : memref<32x32xf16, strided<[1024, 1], offset: ?>>)
+    scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c1024, %c1024) step (%c16, %c64) {
Comment (Contributor Author):
Increased the tile size for the Y axis to test the problematic case.
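A back-of-envelope check of what the new shapes imply, shown as a small sketch below; the 8x16 DPAS accumulator tile, the single A load descriptor, and the 16-row B descriptors are assumptions read off the CHECK lines that follow, not values stated in the test itself.

// Hypothetical helper, not part of the patch: reproduces the iter_args
// counts seen in the CHECK lines under the assumed tile sizes.
#include <cstdio>

int main() {
  auto iterArgs = [](int m, int n) {
    const int accTiles = (m / 8) * (n / 16); // 8x16 DPAS accumulator tiles of C
    const int aDescs = 1;      // A fits one load descriptor in both cases (per CHECKs)
    const int bDescs = n / 16; // one descriptor per 16 rows of transposed B
    return accTiles + aDescs + bDescs;
  };
  std::printf("old 32x32 C tile: %d iter_args\n", iterArgs(32, 32)); // 11
  std::printf("new 16x64 C tile: %d iter_args\n", iterArgs(16, 64)); // 13
  return 0;
}

Under these assumptions, the wider Y tile forces four B descriptors instead of two, which is exactly what the new tB2/tB3 CHECK lines and the 11 to 13 iter_args change reflect.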

+      %subview_0 = memref.subview %arg2[%arg3, %arg4] [16, 64] [1, 1] : memref<1024x1024xf16> to memref<16x64xf16, strided<[1024, 1], offset: ?>>
+      %subview_1 = memref.subview %arg0[%arg3, 0] [16, 1024] [1, 1] : memref<1024x1024xf16> to memref<16x1024xf16, strided<[1024, 1], offset: ?>>
+      %subview_2 = memref.subview %arg1[%arg4, 0] [64, 1024] [1, 1] : memref<1024x1024xf16> to memref<64x1024xf16, strided<[1024, 1], offset: ?>>
+      linalg.matmul_transpose_b ins(%subview_1, %subview_2 : memref<16x1024xf16, strided<[1024, 1], offset: ?>>, memref<64x1024xf16, strided<[1024, 1], offset: ?>>) outs(%subview_0 : memref<16x64xf16, strided<[1024, 1], offset: ?>>)
       scf.reduce
     }
     return
@@ -19,7 +20,7 @@ module {
 // CHECK-LABEL: func.func @matmul_transpose_b
 // CHECK-SAME: %[[Ap:.+]]: memref<1024x1024xf16>, %[[Bp:.+]]: memref<1024x1024xf16>, %[[Cp:.+]]: memref<1024x1024xf16>

-// CHECK: scf.parallel (%[[iter1:.+]], %[[iter2:.+]]) = (%c0, %c0) to (%c1024, %c1024) step (%c32, %c32) {
+// CHECK: scf.parallel (%[[iter1:.+]], %[[iter2:.+]]) = (%c0, %c0) to (%c1024, %c1024) step (%c16, %c64) {
 // CHECK: %[[C:.+]] = memref.subview %[[Cp]][%[[iter1]], %[[iter2]]] {{.*}}
 // CHECK: %[[A:.+]] = memref.subview %[[Ap]][%[[iter1]], 0] {{.*}}
 // CHECK: %[[B:.+]] = memref.subview %[[Bp]][%[[iter2]], 0] {{.*}}
@@ -43,9 +44,11 @@ module {
 // CHECK: %[[rootB:.+]] = xegpu.create_nd_tdesc %[[B]]
 // CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [%c0, %c0]
 // CHECK: %[[tB1:.+]] = xegpu.update_nd_offset %[[rootB]], [%c16, %c0]
+// CHECK: %[[tB2:.+]] = xegpu.update_nd_offset %[[rootB]], [%c32, %c0]
+// CHECK: %[[tB3:.+]] = xegpu.update_nd_offset %[[rootB]], [%c48, %c0]
Comment on lines 45 to +48 (Contributor Author):
It used to do something like:

xegpu.update_nd_offset %[[rootB]], [%c0, %c0]
xegpu.update_nd_offset %[[rootB]], [%c16, %c0]
xegpu.update_nd_offset %[[rootB]], [%c32, %c16]
xegpu.update_nd_offset %[[rootB]], [%c16, %c32]
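For comparison, a sketch of the offsets the fixed loop emits for this case; plain ints again, with the 16x64 B load split into 16x16 descriptors and array_length of 1 assumed from the test rather than taken from the pass verbatim.

// Mirrors the structure of the fixed offset-generation loop with plain ints;
// shapes are inferred from the test above, not from the pass itself.
#include <cstdio>
#include <utility>

int main() {
  const bool transpose = true;
  const int loadRows = 16, loadCols = 64; // B tile shape before transposition
  const int descRows = 16, descCols = 16; // per-descriptor tile shape
  for (int i = 0; i < loadRows; i += descRows) {
    for (int j = 0; j < loadCols; j += descCols) {
      auto [r, c] = transpose ? std::pair<int, int>{j, i}
                              : std::pair<int, int>{i, j};
      // Prints [0,0] [16,0] [32,0] [48,0], matching tB through tB3 above.
      std::printf("update_nd_offset [%d, %d]\n", r, c);
    }
  }
  return 0;
}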


 // Create DPAS computation loop over tiled reduction dimension.
-// CHECK: %[[res:.+]]:11 = scf.for{{.*}}%c0 to %c1024 step %c16
+// CHECK: %[[res:.+]]:13 = scf.for{{.*}}%c0 to %c1024 step %c16
 // CHECK-SAME: iter_args(%[[acc:.+]] = %[[vC_f32]],{{.*}}%[[iterA:.+]] = %[[tA]],{{.*}}%[[iterB:.+]] = %[[tB]],{{.*}}%[[iterB1:.+]] = %[[tB1]]
 // CHECK-SAME: {

@@ -66,10 +69,10 @@ module {

 // Extract DPAS-sized chunks from larger loaded tile A.
 // Tile B is already in the correct shape.
-// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16>
-// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16>
+// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<16x16xf16> to vector<256xf16>
+// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<256xf16> to vector<128xf16>
 // CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x16xf16>
-// CHECK-COUNT-3: vector.extract_strided_slice
+// CHECK-COUNT-1: vector.extract_strided_slice

 // Perform DPAS computation.
 // CHECK: %[[dpas:.+]] = xegpu.dpas %[[vA_dpas]], %[[vB]], %[[acc]]