Skip to content

Commit cbbc931

Browse files
authored
[ROCM] Update bounds for large f16 data-tiling ukernel (#22481)
Signed-off-by: Jorn Tuyls <[email protected]>
1 parent 3bd426c commit cbbc931

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

compiler/plugins/target/ROCM/Dialect/ROCM/IR/ROCMAttrs.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,21 @@ static bool checkIterationSizeConstraints(ArrayRef<int64_t> iterationSizes,
8181
if (indexVal < 0 || indexVal >= iterationSizes.size()) {
8282
return false;
8383
}
84+
// For now, assume a dynamic dimension is very large and any division
85+
// constraint is satisfied to keep the performance state on current models
86+
// (llama) as is.
87+
// TODO(#22370): This is not ideal and can be improved once we support value
88+
// bounds on dynamic dimensions for encodings.
8489
if (IntegerAttr sizeMin = constraint.getSizeMin()) {
85-
if (iterationSizes[indexVal] < sizeMin.getInt()) {
90+
if (ShapedType::isStatic(iterationSizes[indexVal]) &&
91+
iterationSizes[indexVal] < sizeMin.getInt()) {
8692
return false;
8793
}
8894
}
8995
if (IntegerAttr sizeMax = constraint.getSizeMax()) {
96+
if (ShapedType::isDynamic(iterationSizes[indexVal])) {
97+
return false;
98+
}
9099
if (iterationSizes[indexVal] > sizeMax.getInt()) {
91100
return false;
92101
}
@@ -95,6 +104,9 @@ static bool checkIterationSizeConstraints(ArrayRef<int64_t> iterationSizes,
95104
if (sizeDiv.getInt() <= 0) {
96105
return false;
97106
}
107+
if (ShapedType::isDynamic(iterationSizes[indexVal])) {
108+
return true;
109+
}
98110
if (iterationSizes[indexVal] % sizeDiv.getInt()) {
99111
return false;
100112
}

compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_dt_matmul_f16.mlir

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,24 @@
1717
util.func @pingpong_dt_large_f16(%lhs_base: !lhs_base_ty, %rhs_base: !rhs_base_ty, %unused_acc: !acc_base_ty) -> !acc_base_ty attributes {
1818
ukernel_info = #rocm.ukernel_info<
1919
match = {
20-
types = [f16, f16, f32]
20+
types = [f16, f16, f32],
21+
iteration_sizes_constraints = [
22+
#rocm.ukernel_interation_size_constraint<
23+
index = 0,
24+
size_min = 512,
25+
size_div = 64
26+
>,
27+
#rocm.ukernel_interation_size_constraint<
28+
index = 1,
29+
size_min = 32832,
30+
size_div = 64
31+
>,
32+
#rocm.ukernel_interation_size_constraint<
33+
index = 2,
34+
size_min = 512,
35+
size_div = 64
36+
>
37+
]
2138
},
2239
mma = #iree_gpu.data_tiled_mma_layout<
2340
intrinsic = MFMA_F32_16x16x16_F16,

0 commit comments

Comments
 (0)