
Commit fcae3fc

[ROCM][DT] Update ukernel data layout (#22350)
#22284 changes the data tiling layout by removing `moveCrossThreadOutermost`. This PR updates the ukernels accordingly to ensure correct matching. Numerical correctness and performance have been verified locally on Llama 8B prefill.

Closes: #22349

Signed-off-by: Yu-Zhewen <[email protected]>
1 parent cc164ae

File tree: 4 files changed, +15 -15 lines

compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns_driver.mlir
compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_dt_matmul_f16.mlir
compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_dt_matmul_f8E4M3FNUZ.mlir
compiler/plugins/target/ROCM/builtins/mlir_ukernel/ukernel_patterns_gfx942.mlir
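All four files make the same substitution: the two adjacent size-4 dims in the old LHS layouts (presumably the cross-thread dims, given the removed `moveCrossThreadOutermost` step) fold into a single dim of size 16. A minimal shape-only sketch for the f8 large LHS, using the tensor types from the diffs below; it illustrates only the dimension arithmetic (4 x 4 = 16), not the thread-to-element remapping that #22284 actually performs:

// Shape-only sketch: the old 8-D LHS relates to the new 7-D LHS by merging
// the two adjacent size-4 dims (d5, d6) into one size-16 dim.
%new_lhs = tensor.collapse_shape %old_lhs [[0], [1], [2], [3], [4], [5, 6], [7]]
    : tensor<1x128x2x8x4x4x4x8xf8E4M3FNUZ> into tensor<1x128x2x8x4x16x8xf8E4M3FNUZ>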

compiler/plugins/target/ROCM/Dialect/ROCM/Transforms/test/apply_builtin_ukernel_pdl_patterns_driver.mlir

Lines changed: 6 additions & 6 deletions
@@ -362,15 +362,15 @@ module attributes {
 module attributes {
   hal.executable.target = #executable_target_rocm_hsaco_fb
 } {
-  func.func @inner_tiled_f8_large(%arg0: tensor<1x128x2x8x4x4x4x8xf8E4M3FNUZ>, %arg1: tensor<16x128x4x4x4x16x8xf8E4M3FNUZ>) -> tensor<1x16x2x4x8x4x4x16x4xf32> {
+  func.func @inner_tiled_f8_large(%arg0: tensor<1x128x2x8x4x16x8xf8E4M3FNUZ>, %arg1: tensor<16x128x4x4x4x16x8xf8E4M3FNUZ>) -> tensor<1x16x2x4x8x4x4x16x4xf32> {
     %cst = arith.constant 0.000000e+00 : f32
     %0 = tensor.empty() : tensor<1x16x2x4x8x4x4x16x4xf32>
     %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x16x2x4x8x4x4x16x4xf32>) -> tensor<1x16x2x4x8x4x4x16x4xf32>
     %2 = iree_codegen.inner_tiled ins(%arg0, %arg1) outs(%1){
       indexing_maps = [#map1, #map2, #map3],
       iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
       kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x32_F8E4M3FNUZ, intrinsics_m = 8, subgroups_m = 2, intrinsics_n = 4, subgroups_n = 4>
-    } : tensor<1x128x2x8x4x4x4x8xf8E4M3FNUZ>, tensor<16x128x4x4x4x16x8xf8E4M3FNUZ> into tensor<1x16x2x4x8x4x4x16x4xf32>
+    } : tensor<1x128x2x8x4x16x8xf8E4M3FNUZ>, tensor<16x128x4x4x4x16x8xf8E4M3FNUZ> into tensor<1x16x2x4x8x4x4x16x4xf32>
     return %2 : tensor<1x16x2x4x8x4x4x16x4xf32>
   }
 }
@@ -396,15 +396,15 @@ module attributes {
 module attributes {
   hal.executable.target = #executable_target_rocm_hsaco_fb
 } {
-  func.func @inner_tiled_f8_medium(%arg0: tensor<1x64x8x4x4x4x2x8xf8E4M3FNUZ>, %arg1: tensor<4x64x8x2x4x16x2x8xf8E4M3FNUZ>) -> tensor<1x4x8x8x2x4x16x4xf32> {
+  func.func @inner_tiled_f8_medium(%arg0: tensor<1x64x8x4x16x2x8xf8E4M3FNUZ>, %arg1: tensor<4x64x8x2x4x16x2x8xf8E4M3FNUZ>) -> tensor<1x4x8x8x2x4x16x4xf32> {
     %cst = arith.constant 0.000000e+00 : f32
     %0 = tensor.empty() : tensor<1x4x8x8x2x4x16x4xf32>
     %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x4x8x8x2x4x16x4xf32>) -> tensor<1x4x8x8x2x4x16x4xf32>
     %2 = iree_codegen.inner_tiled ins(%arg0, %arg1) outs(%1){
       indexing_maps = [#map1, #map2, #map3],
       iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
       kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x32_F8E4M3FNUZ, intrinsics_m = 8, intrinsics_n = 2, subgroups_n = 8, intrinsics_k = 2>
-    } : tensor<1x64x8x4x4x4x2x8xf8E4M3FNUZ>, tensor<4x64x8x2x4x16x2x8xf8E4M3FNUZ> into tensor<1x4x8x8x2x4x16x4xf32>
+    } : tensor<1x64x8x4x16x2x8xf8E4M3FNUZ>, tensor<4x64x8x2x4x16x2x8xf8E4M3FNUZ> into tensor<1x4x8x8x2x4x16x4xf32>
     return %2 : tensor<1x4x8x8x2x4x16x4xf32>
   }
 }
@@ -430,15 +430,15 @@ module attributes {
 module attributes {
   hal.executable.target = #executable_target_rocm_hsaco_fb
 } {
-  func.func @inner_tiled_f16_large(%arg0: tensor<1x256x2x8x4x4x4x4xf16>, %arg1: tensor<501x256x4x4x4x16x4xf16>) -> tensor<1x501x2x4x8x4x4x16x4xf32> {
+  func.func @inner_tiled_f16_large(%arg0: tensor<1x256x2x8x4x16x4xf16>, %arg1: tensor<501x256x4x4x4x16x4xf16>) -> tensor<1x501x2x4x8x4x4x16x4xf32> {
     %cst = arith.constant 0.000000e+00 : f32
     %0 = tensor.empty() : tensor<1x501x2x4x8x4x4x16x4xf32>
     %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<1x501x2x4x8x4x4x16x4xf32>) -> tensor<1x501x2x4x8x4x4x16x4xf32>
     %2 = iree_codegen.inner_tiled ins(%arg0, %arg1) outs(%1){
       indexing_maps = [#map1, #map2, #map3],
       iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>],
       kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x16_F16, intrinsics_m = 8, subgroups_m = 2, intrinsics_n = 4, subgroups_n = 4>
-    } : tensor<1x256x2x8x4x4x4x4xf16>, tensor<501x256x4x4x4x16x4xf16> into tensor<1x501x2x4x8x4x4x16x4xf32>
+    } : tensor<1x256x2x8x4x16x4xf16>, tensor<501x256x4x4x4x16x4xf16> into tensor<1x501x2x4x8x4x4x16x4xf32>
     return %2 : tensor<1x501x2x4x8x4x4x16x4xf32>
   }
 }
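In all three tests only the LHS type changes; the RHS and accumulator types are untouched. Old vs. new LHS layouts, as updated above:

f8 large:  tensor<1x128x2x8x4x4x4x8xf8E4M3FNUZ> -> tensor<1x128x2x8x4x16x8xf8E4M3FNUZ>
f8 medium: tensor<1x64x8x4x4x4x2x8xf8E4M3FNUZ> -> tensor<1x64x8x4x16x2x8xf8E4M3FNUZ>
f16 large: tensor<1x256x2x8x4x4x4x4xf16> -> tensor<1x256x2x8x4x16x4xf16>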

compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_dt_matmul_f16.mlir

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 // RUN: iree-opt %s

 !acc_base_ty = tensor<1x1x2x4x8x4x4x16x4xf32>
-!lhs_base_ty = tensor<1x?x2x8x4x4x4x4xf16>
+!lhs_base_ty = tensor<1x?x2x8x4x16x4xf16>
 !lhs_expand_ty = tensor<1x?x4x2x8x4x4x2x2x4xf16>
 !rhs_base_ty = tensor<1x?x4x4x4x16x4xf16>
 !rhs_expand_ty = tensor<1x?x4x4x4x4x8x2x4xf16>
@@ -44,7 +44,7 @@ util.func @pingpong_dt_large_f16(%lhs_base: !lhs_base_ty, %rhs_base: !rhs_base_t
   %dim = tensor.dim %rhs_base, %c1 : !rhs_base_ty
   %nDim = arith.divui %dim, %c4 : index

-  %lhs_expand = tensor.expand_shape %lhs_base [[0], [1, 2], [3], [4], [5], [6], [7, 8], [9]] output_shape [1, %nDim, 2, 2, 8, 4, 4, 2, 2, 4] : !lhs_base_ty into !lhs_expand_ty
+  %lhs_expand = tensor.expand_shape %lhs_base [[0], [1, 2], [3], [4], [5], [6, 7, 8], [9]] output_shape [1, %nDim, 2, 2, 8, 4, 4, 2, 2, 4] : !lhs_base_ty into !lhs_expand_ty
   %rhs_expand = tensor.expand_shape %rhs_base [[0], [1, 2], [3], [4], [5], [6, 7], [8]] output_shape [1, %nDim, 2, 4, 4, 4, 8, 2, 4] : !rhs_base_ty into !rhs_expand_ty

   %lhs = tensor.collapse_shape %lhs_expand [[0, 1], [2], [3, 4], [5, 6, 7], [8, 9]] : !lhs_expand_ty into !in_ty
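Because the new `!lhs_base_ty` has one dim fewer, the reassociation gains a three-way group: source dim 5 (size 16) now expands through [6, 7, 8] into 4 x 2 x 2. A hypothetical static-shape instance of the updated expand (K tile count fixed so that %nDim = 1; the function name is illustrative only), which should verify under iree-opt:

func.func @check_lhs_expand_f16(%lhs: tensor<1x4x2x8x4x16x4xf16>) -> tensor<1x1x4x2x8x4x4x2x2x4xf16> {
  // Same reassociation as the ukernel above, on fully static shapes:
  // dim 5 (16) expands into 4 x 2 x 2 via group [6, 7, 8].
  %0 = tensor.expand_shape %lhs [[0], [1, 2], [3], [4], [5], [6, 7, 8], [9]]
      output_shape [1, 1, 4, 2, 8, 4, 4, 2, 2, 4]
      : tensor<1x4x2x8x4x16x4xf16> into tensor<1x1x4x2x8x4x4x2x2x4xf16>
  return %0 : tensor<1x1x4x2x8x4x4x2x2x4xf16>
}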

compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_dt_matmul_f8E4M3FNUZ.mlir

Lines changed: 4 additions & 4 deletions
@@ -1,15 +1,15 @@
 // RUN: iree-opt %s

 !acc_base_ty = tensor<1x1x2x4x8x4x4x16x4xf32>
-!lhs_base_ty = tensor<1x?x2x8x4x4x4x8xf8E4M3FNUZ>
+!lhs_base_ty = tensor<1x?x2x8x4x16x8xf8E4M3FNUZ>
 !lhs_expand_ty = tensor<1x?x4x2x8x4x4x2x2x8xf8E4M3FNUZ>
 !rhs_base_ty = tensor<1x?x4x4x4x16x8xf8E4M3FNUZ>
 !rhs_expand_ty = tensor<1x?x4x4x4x4x8x2x8xf8E4M3FNUZ>
 !in_ty = tensor<?x4x16x32x16xf8E4M3FNUZ>
 !shared_ty = memref<4x16x64x8xf8E4M3FNUZ, #gpu.address_space<workgroup>>

 !m_acc_base_ty = tensor<1x1x8x8x2x4x16x4xf32>
-!m_lhs_base_ty = tensor<1x?x8x4x4x4x2x8xf8E4M3FNUZ>
+!m_lhs_base_ty = tensor<1x?x8x4x16x2x8xf8E4M3FNUZ>
 !m_lhs_expand_ty = tensor<1x?x2x8x4x4x4x2x8xf8E4M3FNUZ>
 !m_rhs_base_ty = tensor<1x?x8x2x4x16x2x8xf8E4M3FNUZ>
 !m_rhs_expand_ty = tensor<1x?x2x8x2x4x16x2x8xf8E4M3FNUZ>
@@ -61,7 +61,7 @@ util.func @pingpong_dt_large_f8E4M3FNUZ(%lhs_base: !lhs_base_ty, %rhs_base: !rhs
   %dim = tensor.dim %rhs_base, %c1 : !rhs_base_ty
   %nDim = arith.divui %dim, %c4 : index

-  %lhs_expand = tensor.expand_shape %lhs_base [[0], [1, 2], [3], [4], [5], [6], [7, 8], [9]] output_shape [1, %nDim, 4, 2, 8, 4, 4, 2, 2, 8] : !lhs_base_ty into !lhs_expand_ty
+  %lhs_expand = tensor.expand_shape %lhs_base [[0], [1, 2], [3], [4], [5], [6, 7, 8], [9]] output_shape [1, %nDim, 4, 2, 8, 4, 4, 2, 2, 8] : !lhs_base_ty into !lhs_expand_ty
   %rhs_expand = tensor.expand_shape %rhs_base [[0], [1, 2], [3], [4], [5], [6, 7], [8]] output_shape [1, %nDim, 4, 4, 4, 4, 8, 2, 8] : !rhs_base_ty into !rhs_expand_ty

   %lhs = tensor.collapse_shape %lhs_expand [[0, 1], [2], [3, 4], [5, 6, 7], [8, 9]] : !lhs_expand_ty into !in_ty
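The large f8 variant takes the same reassociation update as the f16 ukernel; a hypothetical static-shape instance (%nDim fixed to 1, illustrative name) with dim 5 (size 16) expanding through [6, 7, 8] into 4 x 2 x 2:

func.func @check_lhs_expand_f8_large(%lhs: tensor<1x4x2x8x4x16x8xf8E4M3FNUZ>) -> tensor<1x1x4x2x8x4x4x2x2x8xf8E4M3FNUZ> {
  // Mirrors the updated !lhs_base_ty -> !lhs_expand_ty expand with static dims.
  %0 = tensor.expand_shape %lhs [[0], [1, 2], [3], [4], [5], [6, 7, 8], [9]]
      output_shape [1, 1, 4, 2, 8, 4, 4, 2, 2, 8]
      : tensor<1x4x2x8x4x16x8xf8E4M3FNUZ> into tensor<1x1x4x2x8x4x4x2x2x8xf8E4M3FNUZ>
  return %0 : tensor<1x1x4x2x8x4x4x2x2x8xf8E4M3FNUZ>
}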
@@ -319,7 +319,7 @@ util.func private @pingpong_dt_medium_f8E4M3FNUZ(%lhs_base: !m_lhs_base_ty, %rhs
   %dim = tensor.dim %rhs_base, %c1 : !m_rhs_base_ty
   %nDim = arith.divui %dim, %c2 : index

-  %lhs_expand = tensor.expand_shape %lhs_base [[0], [1, 2], [3], [4], [5], [6], [7], [8]] output_shape [1, %nDim, 2, 8, 4, 4, 4, 2, 8] : !m_lhs_base_ty into !m_lhs_expand_ty
+  %lhs_expand = tensor.expand_shape %lhs_base [[0], [1, 2], [3], [4], [5, 6], [7], [8]] output_shape [1, %nDim, 2, 8, 4, 4, 4, 2, 8] : !m_lhs_base_ty into !m_lhs_expand_ty
   %rhs_expand = tensor.expand_shape %rhs_base [[0], [1, 2], [3], [4], [5], [6], [7], [8]] output_shape [1, %nDim, 2, 8, 2, 4, 16, 2, 8] : !m_rhs_base_ty into !m_rhs_expand_ty

   %lhs = tensor.collapse_shape %lhs_expand [[0, 1], [2], [3], [4, 5, 6], [7, 8]] : !m_lhs_expand_ty into !m_lhs_ty
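For the medium variant the folded dim splits back as a plain 4 x 4: group [5, 6] replaces the former one-to-one groups [5] and [6]. A hypothetical static-shape instance (%nDim fixed to 1, so the source K dim is 2; illustrative name):

func.func @check_m_lhs_expand_f8(%lhs: tensor<1x2x8x4x16x2x8xf8E4M3FNUZ>) -> tensor<1x1x2x8x4x4x4x2x8xf8E4M3FNUZ> {
  // Mirrors the updated !m_lhs_base_ty -> !m_lhs_expand_ty expand with static
  // dims: dim 4 (16) expands into 4 x 4 via group [5, 6].
  %0 = tensor.expand_shape %lhs [[0], [1, 2], [3], [4], [5, 6], [7], [8]]
      output_shape [1, 1, 2, 8, 4, 4, 4, 2, 8]
      : tensor<1x2x8x4x16x2x8xf8E4M3FNUZ> into tensor<1x1x2x8x4x4x4x2x8xf8E4M3FNUZ>
  return %0 : tensor<1x1x2x8x4x4x4x2x8xf8E4M3FNUZ>
}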

compiler/plugins/target/ROCM/builtins/mlir_ukernel/ukernel_patterns_gfx942.mlir

Lines changed: 3 additions & 3 deletions
@@ -720,7 +720,7 @@ pdl.pattern @annotate_inner_tiled_f8E4M3FNUZ_medium : benefit(1) {
   %attr_name = pdl.attribute = "iree_codegen.ukernel"
   pdl.apply_native_constraint "hasAttr"(%generic_op, %attr_name : !pdl.operation, !pdl.attribute) {isNegated = true}

-  %lhs_cast_type = pdl.type : tensor<?x?x8x4x4x4x2x8xf8E4M3FNUZ>
+  %lhs_cast_type = pdl.type : tensor<?x?x8x4x16x2x8xf8E4M3FNUZ>
   pdl.apply_native_constraint "matchCastCompatibleType"(%lhs, %lhs_cast_type : !pdl.value, !pdl.type)
   %rhs_cast_type = pdl.type : tensor<?x?x8x2x4x16x2x8xf8E4M3FNUZ>
   pdl.apply_native_constraint "matchCastCompatibleType"(%rhs, %rhs_cast_type : !pdl.value, !pdl.type)
@@ -777,7 +777,7 @@ pdl.pattern @annotate_inner_tiled_f8E4M3FNUZ_large : benefit(2) {
   %attr_name = pdl.attribute = "iree_codegen.ukernel"
   pdl.apply_native_constraint "hasAttr"(%generic_op, %attr_name : !pdl.operation, !pdl.attribute) {isNegated = true}

-  %lhs_cast_type = pdl.type : tensor<?x?x2x8x4x4x4x8xf8E4M3FNUZ>
+  %lhs_cast_type = pdl.type : tensor<?x?x2x8x4x16x8xf8E4M3FNUZ>
   pdl.apply_native_constraint "matchCastCompatibleType"(%lhs, %lhs_cast_type : !pdl.value, !pdl.type)
   %rhs_cast_type = pdl.type : tensor<?x?x4x4x4x16x8xf8E4M3FNUZ>
   pdl.apply_native_constraint "matchCastCompatibleType"(%rhs, %rhs_cast_type : !pdl.value, !pdl.type)
@@ -834,7 +834,7 @@ pdl.pattern @annotate_inner_tiled_f16_large : benefit(1) {
   %attr_name = pdl.attribute = "iree_codegen.ukernel"
   pdl.apply_native_constraint "hasAttr"(%generic_op, %attr_name : !pdl.operation, !pdl.attribute) {isNegated = true}

-  %lhs_cast_type = pdl.type : tensor<?x?x2x8x4x4x4x4xf16>
+  %lhs_cast_type = pdl.type : tensor<?x?x2x8x4x16x4xf16>
   pdl.apply_native_constraint "matchCastCompatibleType"(%lhs, %lhs_cast_type : !pdl.value, !pdl.type)
   %rhs_cast_type = pdl.type : tensor<?x?x4x4x4x16x4xf16>
   pdl.apply_native_constraint "matchCastCompatibleType"(%rhs, %rhs_cast_type : !pdl.value, !pdl.type)
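With leading `?` dims, each updated cast type still matches any batch size and K-tile count while pinning the inner data-tiled layout; e.g. the `tensor<1x256x2x8x4x16x4xf16>` LHS from the test file above is presumably cast-compatible with `tensor<?x?x2x8x4x16x4xf16>`, whereas operands in the old 8-D layout no longer match, which keeps these ukernel annotations off the stale layout.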
